Skip to content

Commit e487fc9

Browse files
davidberard98 authored and liuyunqi20 committed
[BACKEND] Make ExternElementwise op implement ConditionallySpeculatable (#5079)
ExternElementwise ops have a `pure` attribute that marks the op as pure. If an op is pure, it should also be speculatable. In the reduction/scan ttgir->llvm passes, checks for speculatability are failing for ExternElementwise ops, causing additional conditional handling to be added. This PR makes ExternElementwise ops implement ConditionallySpeculatable and marks the op as speculatable if it is marked as pure. This removes the conditional branches from the generated scan/reduction code.
1 parent 158473c commit e487fc9

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
776776
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
777777
SameOperandsAndResultEncoding,
778778
SameVariadicOperandSize,
779-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
779+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
780+
ConditionallySpeculatable]> {
780781

781782
let description = [{
782783
call an external function $symbol implemented in $libpath/$libname with $args
@@ -788,6 +789,12 @@ def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
788789
let results = (outs TT_Type:$result);
789790

790791
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
792+
793+
let extraClassDeclaration = [{
794+
// Interface method for ConditionallySpeculatable.
795+
Speculation::Speculatability getSpeculatability();
796+
}];
797+
791798
}
792799

793800
//

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,12 @@ void ExternElementwiseOp::getEffects(
10161016
SideEffects::DefaultResource::get());
10171017
}
10181018

1019+
Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
1020+
if (getPure())
1021+
return Speculation::Speculatable;
1022+
return Speculation::NotSpeculatable;
1023+
}
1024+
10191025
// -- ExperimentalTensormapCreateOp --
10201026
LogicalResult ExperimentalTensormapCreateOp::verify() {
10211027
auto rank = getBoxDim().size();

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,3 +1741,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
17411741
#loc3 = loc("inner_call":29:28)
17421742
#loc4 = loc(callsite(#loc3 at #loc1))
17431743
#loc5 = loc(callsite(#loc4 at #loc2))
1744+
1745+
// -----
1746+
1747+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
1748+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
1749+
tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} {
1750+
// CHECK: log1pf_scan
1751+
// non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable.
1752+
// CHECK-NOT: llvm.cond_br
1753+
%40 = "tt.scan"(%39) <{axis = 1 : i32, reverse = false}> ({
1754+
^bb0(%arg5: f32, %arg6: f32):
1755+
%43 = tt.extern_elementwise %arg5 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (f32) -> f32
1756+
%44 = arith.addf %43, %43 : f32
1757+
tt.scan.return %44 : f32
1758+
}) : (tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked>
1759+
tt.return
1760+
}
1761+
}

0 commit comments

Comments
 (0)