Skip to content

Commit 77ba5d7

Browse files
authored
[AMD] Enable ds_read_tr for fp4 packed along K (#7481)
This was defensively disabled in a previous commit but has since been verified to work correctly. FP4, when packed along the K dimension, needs to use ds_read_tr8 when it is loaded from shared memory and a transpose is needed. This is because the packing must stay the same, so we operate on the FP4 values as if they were i8 types; this way we don't change the packing order. Note: the LIT test I've added shows the previous behaviour in comparison to the current one. The old code explicitly checked for dot_scaled usage, so I've written the test around that to demonstrate the new behaviour — although the new behaviour no longer needs to look at dot_scaled at all.
1 parent c944014 commit 77ba5d7

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

test/Conversion/amd/ds_transpose.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,4 +367,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
367367
tt.return
368368
}
369369

370+
// CHECK-LABEL: ds_transpose_fp4_mfma_32
371+
tt.func @ds_transpose_fp4_mfma_32(%arg0: !ttg.memdesc<128x128xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<128x128xi8, #shared1, #smem, mutable>, %arg2: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>) {
372+
// CHECK-COUNT-32: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
373+
// CHECK-NOT: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
374+
%1 = ttg.local_load %arg0 : !ttg.memdesc<128x128xi8, #shared, #smem, mutable> -> tensor<128x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
375+
%2 = ttg.local_load %arg1 : !ttg.memdesc<128x128xi8, #shared1, #smem, mutable> -> tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
376+
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma32>
377+
%3 = tt.dot_scaled %1, %2, %cst_2 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>> * tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> -> tensor<128x128xf32, #mma32>
378+
ttg.local_store %3, %arg2 : tensor<128x128xf32, #mma32> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>
379+
tt.return
380+
}
370381
}

third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,9 @@ struct TransLocalLoadOpConversion
213213
auto bitwidth = typeConverter->convertType(dstTy.getElementType())
214214
.getIntOrFloatBitWidth();
215215

216-
// Triton does not natively support the FP4 type, so it is packed and
217-
// represented as an i8. Currently, the only way to distinguish FP4 from an
218-
// actual int8 is by checking whether the localLoad is used in a scaled dot
219-
// operation, as int8 is never used in one.
220-
bool isFP4 = isUsedByDotScaledOp(localLoad) && bitwidth == 8 &&
221-
dstTy.getElementType().isInteger();
222-
223-
if (isFP4 || (bitwidth != 16 && bitwidth != 8)) {
216+
// FP4 is represented as i8 and, when packed along K, can be
217+
// transposed using ds_read_tr8 which doesn't change packing.
218+
if (bitwidth != 16 && bitwidth != 8) {
224219
return false;
225220
}
226221

0 commit comments

Comments
 (0)