
Commit 37e372c

[Backend] Make dot_scaled an anchor in RemoveLayoutConversion (triton-lang#6263)

This PR adds DotScaledOp as an anchor op in the RemoveLayoutConversion pass so that redundant `ttg.convert_layout` ops around `tt.dot_scaled` can be eliminated, as sketched below.

1 parent e4baf45 commit 37e372c
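
For illustration, here is a sketch of the IR pattern this change targets, distilled from the new test added in test/TritonGPU/combine.mlir below (the layout attributes #mma, #blocked1, #dot_op_a, #dot_op_b, #linear, and #linear1 are the ones defined in that test); this is an illustrative excerpt, not additional code from the commit.

// Before RemoveLayoutConversions: the tt.dot_scaled result in #mma is
// converted to #blocked1 on every iteration only to match the iter_args layout.
%1 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst0) -> (tensor<128x128xf32, #blocked1>) {
  %4 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E4M3FN, #dot_op_a>, tensor<128x4xi8, #linear> * tensor<128x128xf8E4M3FN, #dot_op_b>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
  %5 = ttg.convert_layout %4 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked1>
  scf.yield %5 : tensor<128x128xf32, #blocked1>
}
// After the pass, with tt.dot_scaled treated as a layout anchor, the loop
// carries the #mma layout directly and a single ttg.convert_layout to the
// store layout remains after the loop (see the CHECK lines in the test).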

2 files changed: 43 additions (+), 2 deletions (-)

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 2 additions & 2 deletions

@@ -184,8 +184,8 @@ void LayoutRematerialization::cleanup() {
 bool isLayoutAnchor(Operation *op) {
   if (isa<LoadOp, StoreOp>(op))
     return isExpensiveLoadOrStore(op);
-  if (isa<DotOp, nvidia_gpu::WarpGroupDotOp, AtomicRMWOp, AtomicCASOp,
-          triton::nvidia_gpu::TMEMLoadOp>(op))
+  if (isa<DotOp, DotScaledOp, nvidia_gpu::WarpGroupDotOp, AtomicRMWOp,
+          AtomicCASOp, triton::nvidia_gpu::TMEMLoadOp>(op))
     return true;
   if (auto gatherOp = dyn_cast<GatherOp>(op))
     return gatherOp.getEfficientLayout();

test/TritonGPU/combine.mlir

Lines changed: 41 additions & 0 deletions

@@ -3679,3 +3679,44 @@ module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
   tt.return %1 : tensor<2x16x2xf32, #blocked>
 }
 }
+
+// -----
+
+#linear = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#dot_op_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>
+#dot_op_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>
+// CHECK: [[$BLOCK:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_scaled_no_redundant_convert_layout
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_no_redundant_convert_layout(
+      %arg0: tensor<128x128xf8E4M3FN, #dot_op_a>,
+      %arg1: tensor<128x128xf8E4M3FN, #dot_op_b>,
+      %arg2: tensor<128x4xi8, #linear>,
+      %arg3: tensor<128x4xi8, #linear1>,
+      %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    // CHECK: %[[RET:.+]] = scf.for
+    // CHECK-NEXT: %[[DOT_RET:.+]] = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false}
+    // CHECK-NEXT: scf.yield %[[DOT_RET]]
+    // CHECK-NEXT: }
+    // CHECK-NEXT: ttg.convert_layout %[[RET]] : tensor<128x128xf32, #mma> -> tensor<128x128xf32, [[$BLOCK]]>
+    // CHECK-NEXT: tt.store
+    %1 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst0) -> (tensor<128x128xf32, #blocked1>) {
+      %4 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E4M3FN, #dot_op_a>, tensor<128x4xi8, #linear> * tensor<128x128xf8E4M3FN, #dot_op_b>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+      %5 = ttg.convert_layout %4 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked1>
+      scf.yield %5 : tensor<128x128xf32, #blocked1>
+    }
+    %7 = ttg.convert_layout %1 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %7 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

0 commit comments
