
Commit 4bcdbde

[AMD] Emit shared memory ops for M/N packed FP4 (#7626)
This adds support for FP4 packed along M/N. MFMA only supports K-packed inputs, so we need to transpose the inputs from M/N packed, M/N contiguous to K packed, K contiguous. This is achieved by changing the contiguity of the tensor (as Triton provides K-contiguous data) when storing it into shared memory and then transposing the data using LocalLoadPackedTransposedOp. To keep the scope of the change as localised as possible, this is done directly in AccelerateAMDMatmul. This change achieves two things:

- Store the input tensor in shared memory in an M/N-contiguous way (swapping the shared layout order)
- Transpose the tensor using ds_read_b64_tr4
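For reference, a minimal standalone sketch of the shape and shared-order bookkeeping this rewrite performs on the packed-i8 view of an M/N packed operand, using the operand shapes from the new lit tests. This is an illustration only, not the pass code: PackedView and toKPackedView are made-up names.

#include <array>
#include <cstdio>

struct PackedView {
  std::array<long, 2> shape;     // i8 tensor shape (two FP4 values per byte)
  std::array<unsigned, 2> order; // shared-memory order used for the store
};

// opIdx == 0 -> A with shape [M, K]; opIdx == 1 -> B with shape [K, N].
// For an M/N packed operand the non-K dimension holds two FP4 per byte, so the
// K-packed view produced by the transposed load doubles the non-K dim and
// halves the K dim; the shared layout order is flipped so the store is M/N
// contiguous.
PackedView toKPackedView(std::array<long, 2> shape, int opIdx) {
  int nonKDim = (opIdx == 0) ? 0 : 1;
  int kDim = (opIdx == 0) ? 1 : 0;
  PackedView out;
  out.shape = shape;
  out.shape[nonKDim] *= 2; // unpacked along M/N after the transpose
  out.shape[kDim] /= 2;    // packed along K instead
  out.order = (opIdx == 1) ? std::array<unsigned, 2>{1, 0}  // B: order [1, 0]
                           : std::array<unsigned, 2>{0, 1}; // A: order [0, 1]
  return out;
}

int main() {
  // A operand from the tests: tensor<64x128xi8> (M packed) -> tensor<128x64xi8>.
  PackedView a = toKPackedView({64, 128}, /*opIdx=*/0);
  // B operand from the tests: tensor<128x64xi8> (N packed) -> tensor<64x128xi8>.
  PackedView b = toKPackedView({128, 64}, /*opIdx=*/1);
  std::printf("A: %ldx%ld, order [%u, %u]\n", a.shape[0], a.shape[1], a.order[0], a.order[1]);
  std::printf("B: %ldx%ld, order [%u, %u]\n", b.shape[0], b.shape[1], b.order[0], b.order[1]);
  return 0;
}

The printed shapes and orders match the !ttg.memdesc and dot-operand tensor types in the CHECK lines of the new gfx950 tests below.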
1 parent 3ee52d3 · commit 4bcdbde

File tree

3 files changed: +242 -6 lines


test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir

Lines changed: 71 additions & 0 deletions
@@ -223,3 +223,74 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_scaled_mxfp4_b_packed_mn
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_mxfp4_b_packed_mn(
+    %a: tensor<128x128xf8E5M2, #blocked>,
+    %b: tensor<128x64xi8, #blocked1>,
+    %c: tensor<128x128xf32, #blocked>,
+    %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %b1 = ttg.convert_layout %b : tensor<128x64xi8, #blocked1> -> tensor<128x64xi8, #blocked>
+    // CHECK: %[[ALLOCB:.+]] = ttg.local_alloc {{.*}} : (tensor<128x64xi8, #blocked>) -> !ttg.memdesc<128x64xi8, #shared, #smem>
+    // CHECK: %[[B:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCB]] : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+    // CHECK: tt.dot_scaled %{{.*}}, %[[B]], %{{.*}} lhs = e5m2 rhs = e2m1 {fastMath = false}
+    %accumulator_52 = tt.dot_scaled %a, %b1, %c lhs = e5m2 rhs = e2m1 {fastMath = false, rhs_k_pack = false} : tensor<128x128xf8E5M2, #blocked> * tensor<128x64xi8, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %accumulator_52 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [0, 1]}>
+// CHECK-LABEL: mfma_dot_scaled_mxfp4_a_packed_mn
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_mxfp4_a_packed_mn(
+    %a: tensor<64x128xi8, #blocked>,
+    %b: tensor<128x128xf8E5M2, #blocked1>,
+    %c: tensor<128x128xf32, #blocked>,
+    %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %b1 = ttg.convert_layout %b : tensor<128x128xf8E5M2, #blocked1> -> tensor<128x128xf8E5M2, #blocked>
+    // CHECK: %[[ALLOCA:.+]] = ttg.local_alloc {{.*}} : (tensor<64x128xi8, #blocked>) -> !ttg.memdesc<64x128xi8, #shared, #smem>
+    // CHECK: %[[A:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCA]] : !ttg.memdesc<64x128xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK: tt.dot_scaled %[[A]], %{{.*}}, %{{.*}} lhs = e2m1 rhs = e5m2 {fastMath = false}
+    %accumulator_52 = tt.dot_scaled %a, %b1, %c lhs = e2m1 rhs = e5m2 {fastMath = false, lhs_k_pack = false} : tensor<64x128xi8, #blocked> * tensor<128x128xf8E5M2, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %accumulator_52 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [0, 1]}>
+// CHECK{LITERAL}: #shared1 = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_scaled_mxfp4_ab_packed_mn
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_mxfp4_ab_packed_mn(
+    %a: tensor<64x128xi8, #blocked>,
+    %b: tensor<128x64xi8, #blocked1>,
+    %c: tensor<128x128xf32, #blocked>,
+    %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %b1 = ttg.convert_layout %b : tensor<128x64xi8, #blocked1> -> tensor<128x64xi8, #blocked>
+    // CHECK: %[[ALLOCA:.+]] = ttg.local_alloc {{.*}} : (tensor<64x128xi8, #blocked>) -> !ttg.memdesc<64x128xi8, #shared, #smem>
+    // CHECK: %[[A:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCA]] : !ttg.memdesc<64x128xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK: %[[ALLOCB:.+]] = ttg.local_alloc {{.*}} : (tensor<128x64xi8, #blocked>) -> !ttg.memdesc<128x64xi8, #shared1, #smem>
+    // CHECK: %[[B:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCB]] : !ttg.memdesc<128x64xi8, #shared1, #smem> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+    // CHECK: tt.dot_scaled %[[A]], %[[B]], %{{.*}} lhs = e2m1 rhs = e2m1 {fastMath = false}
+    %accumulator_52 = tt.dot_scaled %a, %b1, %c lhs = e2m1 rhs = e2m1 {fastMath = false, lhs_k_pack = false, rhs_k_pack = false} : tensor<64x128xi8, #blocked> * tensor<128x64xi8, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %accumulator_52 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 120 additions & 0 deletions
@@ -797,3 +797,123 @@ tt.func @pipeline_fp64_with_async_copy_gfx950(
     tt.return %loop: tensor<128x128xf64, #C>
   }
 }
+
+// -----
+
+// COMMON-LABEL: pipelining_local_load_packed_transposed
+
+// Prologue
+// COMMON: ttg.local_alloc
+// COMMON: ttg.local_alloc
+// ASYNC: ttg.async_copy_global_to_local
+// SYNC: tt.load
+// COMMON: tt.load
+// SYNC: ttg.local_store
+// COMMON: ttg.local_store
+
+// Main loop
+// COMMON: scf.for
+// COMMON: ttg.local_load
+// COMMON: amdgpu.local_load_packed_tranposed
+// COMMON: tt.dot_scaled
+// COMMON: scf.yield
+
+// Epilogue
+// COMMON: ttg.local_load
+// COMMON: amdgpu.local_load_packed_tranposed
+// COMMON: scf.if
+// COMMON: tt.dot_scaled
+// COMMON-COUNT-2: scf.yield
+// COMMON-COUNT-2: ttg.local_dealloc
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pipelining_local_load_packed_transposed(%a_ptr: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %b_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %output_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_scale: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bn: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<128> : tensor<128x128xi32, #blocked>
+    %cst_0 = arith.constant dense<128> : tensor<128x64xi32, #blocked1>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c127_i32 = arith.constant 127 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.addi %M, %c127_i32 : i32
+    %2 = arith.divsi %1, %c128_i32 : i32
+    %3 = arith.remsi %0, %2 : i32
+    %4 = arith.divsi %0, %2 : i32
+    %5 = arith.muli %3, %c128_i32 : i32
+    %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %7 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+    %8 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+    %9 = tt.splat %5 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %10 = tt.splat %5 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+    %11 = arith.addi %9, %6 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %12 = arith.addi %10, %7 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+    %13 = arith.muli %4, %c128_i32 : i32
+    %14 = arith.divsi %13, %c2_i32 : i32
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %16 = tt.splat %14 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %17 = arith.addi %16, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %18 = tt.expand_dims %11 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %19 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2>
+    %20 = tt.splat %stride_am : i32 -> tensor<128x1xi32, #blocked>
+    %21 = arith.muli %18, %20 : tensor<128x1xi32, #blocked>
+    %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %23 = tt.expand_dims %22 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
+    %24 = tt.broadcast %21 : tensor<128x1xi32, #blocked> -> tensor<128x128xi32, #blocked>
+    %25 = tt.broadcast %23 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked>
+    %26 = arith.addi %24, %25 : tensor<128x128xi32, #blocked>
+    %27 = tt.splat %a_ptr : !tt.ptr<f8E5M2> -> tensor<128x128x!tt.ptr<f8E5M2>, #blocked>
+    %28 = tt.addptr %27, %26 : tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x128xi32, #blocked>
+    %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %30 = tt.expand_dims %29 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
+    %31 = tt.expand_dims %17 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %32 = tt.splat %stride_bn : i32 -> tensor<1x64xi32, #blocked1>
+    %33 = arith.muli %31, %32 : tensor<1x64xi32, #blocked1>
+    %34 = tt.broadcast %30 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+    %35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+    %36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1>
+    %37 = tt.splat %b_ptr : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked1>
+    %38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr<i8>, #blocked1>, tensor<128x64xi32, #blocked1>
+    %39 = arith.addi %K, %c127_i32 : i32
+    %40 = arith.divsi %39, %c128_i32 : i32
+    %accumulator:3 = scf.for %accumulator_2 = %c0_i32 to %40 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %28, %arg13 = %38) -> (tensor<128x128xf32, #mma>, tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x64x!tt.ptr<i8>, #blocked1>) : i32 {
+      %60 = tt.load %arg12 : tensor<128x128x!tt.ptr<f8E5M2>, #blocked>
+      %61 = tt.load %arg13 : tensor<128x64x!tt.ptr<i8>, #blocked1>
+      %62 = ttg.convert_layout %60 : tensor<128x128xf8E5M2, #blocked> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+      %63 = ttg.local_alloc %61 : (tensor<128x64xi8, #blocked1>) -> !ttg.memdesc<128x64xi8, #shared, #smem>
+      %64 = amdgpu.local_load_packed_tranposed %63 : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+      %65 = tt.dot_scaled %62, %64, %arg11 lhs = e5m2 rhs = e2m1 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<128x128xf32, #mma>
+      %66 = tt.addptr %arg12, %cst : tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x128xi32, #blocked>
+      %67 = tt.addptr %arg13, %cst_0 : tensor<128x64x!tt.ptr<i8>, #blocked1>, tensor<128x64xi32, #blocked1>
+      scf.yield %65, %66, %67 : tensor<128x128xf32, #mma>, tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x64x!tt.ptr<i8>, #blocked1>
+    } {tt.num_stages = 2 : i32}
+    %41 = tt.splat %13 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+    %42 = arith.addi %41, %8 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+    %43 = tt.splat %stride_cm : i32 -> tensor<128x1xi32, #blocked2>
+    %44 = arith.muli %43, %19 : tensor<128x1xi32, #blocked2>
+    %45 = tt.splat %output_ptr : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked2>
+    %46 = tt.addptr %45, %44 : tensor<128x1x!tt.ptr<f32>, #blocked2>, tensor<128x1xi32, #blocked2>
+    %47 = tt.expand_dims %42 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi32, #blocked2>
+    %48 = tt.broadcast %46 : tensor<128x1x!tt.ptr<f32>, #blocked2> -> tensor<128x128x!tt.ptr<f32>, #blocked2>
+    %49 = tt.broadcast %47 : tensor<1x128xi32, #blocked2> -> tensor<128x128xi32, #blocked2>
+    %50 = tt.addptr %48, %49 : tensor<128x128x!tt.ptr<f32>, #blocked2>, tensor<128x128xi32, #blocked2>
+    %51 = tt.splat %M : i32 -> tensor<128x1xi32, #blocked2>
+    %52 = arith.cmpi slt, %19, %51 : tensor<128x1xi32, #blocked2>
+    %53 = tt.splat %N : i32 -> tensor<1x128xi32, #blocked2>
+    %54 = arith.cmpi slt, %47, %53 : tensor<1x128xi32, #blocked2>
+    %55 = tt.broadcast %52 : tensor<128x1xi1, #blocked2> -> tensor<128x128xi1, #blocked2>
+    %56 = tt.broadcast %54 : tensor<1x128xi1, #blocked2> -> tensor<128x128xi1, #blocked2>
+    %57 = arith.andi %55, %56 : tensor<128x128xi1, #blocked2>
+    %58 = ttg.convert_layout %50 : tensor<128x128x!tt.ptr<f32>, #blocked2> -> tensor<128x128x!tt.ptr<f32>, #mma>
+    %59 = ttg.convert_layout %57 : tensor<128x128xi1, #blocked2> -> tensor<128x128xi1, #mma>
+    tt.store %58, %accumulator#0, %59 : tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 51 additions & 6 deletions
@@ -209,7 +209,7 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
                                                int mfmaVersion, int nonKDim) {
   auto ctx = dot.getContext();
   int64_t inputKDim = dot.getA().getType().getShape().back();
-  if (dot.getAElemType() == ScaleDotElemType::E2M1) {
+  if (dot.getAElemType() == ScaleDotElemType::E2M1 && dot.getLhsKPack()) {
     // Since two fp4 are packed into int8, to get the correct K dim size, we
     // need to multiply it by 2.
     inputKDim *= 2;
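To make the effect of the new getLhsKPack() guard above concrete, here is a small worked example with the logical 128x128 FP4 A operand from the tests (illustrative only; logicalK is a made-up helper, not part of the pass):

#include <cassert>

// Two FP4 values share one i8 byte, so the i8 K extent only needs the *2
// correction when the packing is along K:
//   K packed: stored as tensor<128x64xi8> -> shape.back() = 64,  logical K = 64 * 2 = 128
//   M packed: stored as tensor<64x128xi8> -> shape.back() = 128, logical K is already 128
long logicalK(long i8KExtent, bool kPacked) {
  return kPacked ? i8KExtent * 2 : i8KExtent; // mirrors the guarded inputKDim *= 2
}

int main() {
  assert(logicalK(64, /*kPacked=*/true) == 128);
  assert(logicalK(128, /*kPacked=*/false) == 128);
  return 0;
}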
@@ -928,11 +928,56 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
       auto newEnc =
           DotOperandEncodingAttr::get(ctx, opIdx, mfmaEnc, kWidth / 2);

-      (opIdx == 0 ? aEncLL : bEncLL) *=
-          newEnc.toLinearLayout(opIdx == 0 ? aShape : bShape);
-      auto newVType = RankedTensorType::get(vType.getShape(),
-                                            vType.getElementType(), newEnc);
-      return rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+      bool kPacked = opIdx == 0 ? dotOp.getLhsKPack() : dotOp.getRhsKPack();
+      if (kPacked == false) {
+        // This is FP4 with M/N packing. Create local alloc + local load here
+        // so we have control of the shared layout
+        // A, M packed: tensor<16x64xi8> --> 32x32
+        // B, N packed: tensor<64x16xi8> --> 32x32
+        SmallVector<int64_t> newShape(vType.getShape());
+        newShape[opIdx == 0 ? 0 : 1] = newShape[opIdx == 0 ? 0 : 1] * 2;
+        newShape[opIdx == 0 ? 1 : 0] = newShape[opIdx == 0 ? 1 : 0] / 2;
+        auto newVType =
+            RankedTensorType::get(newShape, vType.getElementType(), newEnc);
+        OpBuilder builder(dotOp);
+        auto srcEncoding = vType.getEncoding();
+        auto originalOrder = triton::gpu::getOrderForMemory(vType);
+        SmallVector<unsigned> newOrder = originalOrder;
+        if (opIdx == 1) {
+          newOrder = {1, 0};
+        } else {
+          newOrder = {0, 1};
+        }
+        auto sharedMemorySpace =
+            triton::gpu::SharedMemorySpaceAttr::get(vType.getContext());
+        auto tmpType = triton::gpu::MemDescType::get(
+            vType.getShape(), vType.getElementType(),
+            triton::gpu::SwizzledSharedEncodingAttr::get(
+                v.getContext(), newEnc, vType.getShape(), newOrder,
+                triton::gpu::getCTALayout(srcEncoding), vType.getElementType()),
+            sharedMemorySpace);
+        auto tmp = builder.create<triton::gpu::LocalAllocOp>(dotOp.getLoc(),
+                                                             tmpType, v);
+        auto newConvert =
+            builder.create<triton::amdgpu::LocalLoadPackedTransposedOp>(
+                dotOp.getLoc(), newVType, tmp);
+        if (opIdx == 0) {
+          aShape = newConvert.getType().getShape();
+          aEncLL *= newEnc.toLinearLayout(aShape);
+        } else {
+          bShape = newConvert.getType().getShape();
+          bEncLL *= newEnc.toLinearLayout(bShape);
+        }
+        return newConvert;
+      } else {
+        if (opIdx == 0)
+          aEncLL *= newEnc.toLinearLayout(aShape);
+        else
+          bEncLL *= newEnc.toLinearLayout(bShape);
+        auto newVType = RankedTensorType::get(vType.getShape(),
+                                              vType.getElementType(), newEnc);
+        return rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+      }
     };
     a = convertInputLayout(a, 0);
     b = convertInputLayout(b, 1);
