
Commit a1acb06

borontion authored and anmyachev committed
[AMD] Make kWidth mandatory for WMMA v3 (#8783)
Currently we limit WMMA v3's kWidth to {2, 8, 16}, which matches the hardware view for all possible WMMA instructions. In the wmma_scaled case, we assume kWidth is always 16. But in the attention kernel we can use kWidth = 8, which removes the layout convert between the two dots. This does not match the hardware view of contiguous elements along the k dimension, but the results are still correct as long as the two operands use the same kWidth, since the k elements of A and B then stay paired the same way under the reduction. This PR removes the kWidth value check for WMMA v3 and instead makes kWidth mandatory, same as for MFMA.
1 parent 7f153c3 commit a1acb06
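
For illustration, the chained-dot pattern this enables looks roughly like the sketch below. This is a minimal sketch only: the #mma attribute and shapes follow the test added in this commit, the first dot and its scales are elided, and %mm0 stands for the accumulator of the first scaled dot. Both operands of the second dot carry kWidth = 8, so the first dot's result can be converted once and fed in directly, with no extra relayout between the two dots:

#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape = [16, 16, 128]}>

// %mm0 : tensor<128x128xf32, #mma>  -- accumulator of the first scaled dot.
// Reuse it as operand A of the second dot with kWidth = 8; operand B of the
// second dot is also loaded with kWidth = 8, so the two operands match.
%op0 = ttg.convert_layout %mm0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>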

File tree: 2 files changed (+28 −2 lines)


lib/Dialect/TritonGPU/IR/Dialect.cpp (2 additions, 2 deletions)

@@ -2505,9 +2505,9 @@ LogicalResult DotOperandEncodingAttr::verify(
     return emitError()
            << "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 "
               "(including packed cases for `scaled_dot`)";
-  if (parentAttr.getVersion() == 3 && !llvm::is_contained({2, 8, 16}, kWidth))
+  if (parentAttr.getVersion() == 3 && kWidth == 0)
     return emitError()
-           << "ttg.dot_op kWidth parameter must be 2/8/16 for WMMA v3";
+           << "ttg.dot_op kWidth parameter is mandatory for WMMA v3 ";
   return success();
 }
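
In other words, the verifier no longer constrains which kWidth value is used for a version-3 parent; it only rejects the unset value 0. A small illustrative sketch of the two cases (hypothetical attribute snippets, assuming kWidth defaults to 0 when omitted, as the check above implies):

#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape = [16, 16, 128]}>
// Accepted after this change: any explicit kWidth, e.g. the 8 used by the chained attention dots.
#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
// Rejected: leaving kWidth unset now emits
// "ttg.dot_op kWidth parameter is mandatory for WMMA v3".
#ttg.dot_op<{opIdx = 0, parent = #mma}>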

test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir (26 additions, 0 deletions)

@@ -200,3 +200,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape=[16, 16, 128]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: wmma_scaled_dot_fp8_chained
+  tt.func @wmma_scaled_dot_fp8_chained(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg2: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %out0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %scale0 = arith.constant dense<127> : tensor<128x4xi8, #linear>
+    %scale1 = arith.constant dense<127> : tensor<128x4xi8, #linear1>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+    %mm0 = tt.dot_scaled %arg0 scale %scale0, %arg2 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+    // CHECK-NOT: rocdl.ds_swizzle
+    // CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
+    %op0 = ttg.convert_layout %mm0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %op1 = tt.fp_to_fp %op0, rounding = rtne : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+    %mm1 = tt.dot_scaled %op1 scale %scale0, %arg3 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+    %ptr0 = tt.splat %out0 : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.store %ptr0, %mm1 : tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.return
+  }
+}
