Skip to content

Commit 781774c

Browse files
[BACKEND] Make ExternElementwise op implement ConditionallySpeculatable (#5079)
ExternElementwise ops have a `pure` attribute that marks the op as pure. If an op is pure, it should also be speculatable. In the reduction/scan ttgir->llvm passes, checks for speculatability are failing for ExternElementwise ops, causing additional conditional handling to be added. This PR makes ExternElementwise ops implement ConditionallySpeculatable, and mark the op as speculatable if the op is marked as pure. This removes the conditional branches from the generated scan/reduction code.
1 parent 03968e6 commit 781774c

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
778778
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
779779
SameOperandsAndResultEncoding,
780780
SameVariadicOperandSize,
781-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
781+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
782+
ConditionallySpeculatable]> {
782783

783784
let description = [{
784785
call an external function $symbol implemented in $libpath/$libname with $args
@@ -790,6 +791,12 @@ def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
790791
let results = (outs TT_Type:$result);
791792

792793
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
794+
795+
let extraClassDeclaration = [{
796+
// Interface method for ConditionallySpeculatable.
797+
Speculation::Speculatability getSpeculatability();
798+
}];
799+
793800
}
794801

795802
//

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,12 @@ void ExternElementwiseOp::getEffects(
10391039
SideEffects::DefaultResource::get());
10401040
}
10411041

1042+
Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
1043+
if (getPure())
1044+
return Speculation::Speculatable;
1045+
return Speculation::NotSpeculatable;
1046+
}
1047+
10421048
// -- ExperimentalTensormapCreateOp --
10431049
LogicalResult ExperimentalTensormapCreateOp::verify() {
10441050
auto rank = getBoxDim().size();

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,3 +1855,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
18551855
#loc3 = loc("inner_call":29:28)
18561856
#loc4 = loc(callsite(#loc3 at #loc1))
18571857
#loc5 = loc(callsite(#loc4 at #loc2))
1858+
1859+
// -----
1860+
1861+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
1862+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
1863+
tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} {
1864+
// CHECK: log1pf_scan
1865+
// non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable.
1866+
// CHECK-NOT: llvm.cond_br
1867+
%40 = "tt.scan"(%39) <{axis = 1 : i32, reverse = false}> ({
1868+
^bb0(%arg5: f32, %arg6: f32):
1869+
%43 = tt.extern_elementwise %arg5 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (f32) -> f32
1870+
%44 = arith.addf %43, %43 : f32
1871+
tt.scan.return %44 : f32
1872+
}) : (tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked>
1873+
tt.return
1874+
}
1875+
}

0 commit comments

Comments
 (0)