Skip to content

Commit 1444755

Browse files
authored
[LLVMGPU] Add VMFMA for FP8 to align layouts between chained F8 contractions. (#19020)
This PR introduces virtual intrinsics on F8 MFMA that break apart a single 8xF8 read into two interleaved 4xF8 reads from shared memory. The main motivation for this virtual intrinsic is to enable faster F8 attention/chained matmuls. The reason is that, by doing interleaved reads on the K-dimension, we can match the native F8 intrinsic output layout coming from the 1st matmul to the rhs read of the 2nd matmul (with the interleaved virtual MFMA layout). Once the layouts are aligned, we just need to handle it using the to_layout lowering, which does a reshape on the SIMT vector. This PR has been tested on attention of shape: [B: 1, M: 4096, K1: 64, K2: 4096, N: 64] as seen in this IR: [(link)](https://gist.githubusercontent.com/raikonenfnu/4d33b5addfa9c4ec9e76918704251e39/raw/5b20c0c359e3e2df7f8db4890d3cc0590352d18a/attention_f8_perf.mlir) and using this spec to specify the VMFMA on the 2nd matmul and the regular MFMA on the 1st matmul: ([link](https://gist.githubusercontent.com/raikonenfnu/4d33b5addfa9c4ec9e76918704251e39/raw/5b20c0c359e3e2df7f8db4890d3cc0590352d18a/attn_config.mlir)). With this, we were able to get a 1.63x speedup over the reference (same config but using MFMA_16x16x32xF16 on both contractions), with correct/same numerics. Signed-off-by: Stanley Winata <[email protected]>
1 parent f71dd12 commit 1444755

File tree

4 files changed

+95
-1
lines changed

4 files changed

+95
-1
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
236236
case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
237237
return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
238238
}
239-
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
239+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
240+
case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
240241
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
241242
}
242243
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
@@ -420,6 +421,7 @@ MMAAttr::getABCVectorTypes() const {
420421
}
421422
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
422423
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
424+
case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
423425
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
424426
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
425427
auto aType = VectorType::get({8}, getAType());
@@ -471,6 +473,7 @@ int64_t MMAAttr::getBlockSize() const {
471473
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
472474
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
473475
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
476+
case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
474477
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
475478
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
476479
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
@@ -496,6 +499,7 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
496499
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
497500
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
498501
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
502+
case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
499503
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
500504
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
501505
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
@@ -578,6 +582,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
578582
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
579583
/*element=*/{4, 1}};
580584
}
585+
case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
586+
switch (fragment) {
587+
case MMAFragment::Lhs:
588+
return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
589+
/*element=*/{1, 4}};
590+
case MMAFragment::Rhs:
591+
return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
592+
/*element=*/{4, 1}};
593+
case MMAFragment::Acc:
594+
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
595+
/*element=*/{4, 1}};
596+
}
581597
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
582598
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
583599
switch (fragment) {
@@ -711,6 +727,7 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
711727
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
712728
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
713729
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
730+
case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
714731
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
715732
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
716733
auto [m, n, k] = getMNKShape();

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
133133
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
134134
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
135135
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
136+
// The V-intrinsic below splits each 8xF8 read along the K-dim into two interleaved 4xF8 reads.
137+
// (Useful in F8 chained-MM to align B-layout of 2nd MM to C-layout of 1st MM)
138+
def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 0x0941>;
136139
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
137140
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x0981>;
138141

@@ -159,6 +162,7 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
159162
MFMA_F32_32x32x8_BF16,
160163
MFMA_F32_16x16x32_F8E4M3FNUZ,
161164
MFMA_F32_16x16x32_F8E5M2FNUZ,
165+
VMFMA_F32_16x16x32_F8E4M3FNUZ,
162166
MFMA_I32_16x16x32_I8,
163167
MFMA_I32_32x32x16_I8,
164168
MFMA_I32_16x16x16_I8,

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,76 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
718718

719719
// -----
720720

721+
// This test ensures we can generate correct instructions from virtual (V) MFMAs, which differ only in their read layouts.
722+
723+
#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<VMFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
724+
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
725+
726+
#pipeline_layout = #hal.pipeline.layout<bindings = [
727+
#hal.pipeline.binding<storage_buffer>,
728+
#hal.pipeline.binding<storage_buffer>,
729+
#hal.pipeline.binding<storage_buffer>
730+
]>
731+
hal.executable @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 {
732+
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
733+
hal.executable.export @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 layout(#pipeline_layout) {
734+
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
735+
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
736+
hal.return %x, %y, %z : index, index, index
737+
}
738+
builtin.module {
739+
func.func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32() attributes {translation_info = #translation} {
740+
%cst = arith.constant 0.000000e+00 : f32
741+
%c0 = arith.constant 0 : index
742+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
743+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
744+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
745+
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
746+
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
747+
%5 = tensor.empty() : tensor<256x256xf32>
748+
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
749+
%7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
750+
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
751+
return
752+
}
753+
}
754+
}
755+
}
756+
757+
// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32
758+
// CHECK-DAG: %[[ALLOC_LHS:.+]] = memref.alloc() : memref<32x136xf8E4M3FNUZ, #gpu.address_space<workgroup>>
759+
// CHECK-DAG: %[[ALLOC_RHS:.+]] = memref.alloc() : memref<128x40xf8E4M3FNUZ, #gpu.address_space<workgroup>>
760+
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x1x1x4x1xf32>)
761+
762+
// Validate that the VMFMA does 2 interleaved reads and combines them for every MFMA:
763+
764+
// CHECK-COUNT-6: vector.transfer_read %[[ALLOC_LHS]]
765+
// CHECK: %[[SLICE_LHS_0:.+]] = vector.transfer_read %[[ALLOC_LHS]]
766+
// CHECK: %[[VECTOR_LHS_0:.+]] = vector.insert_strided_slice %[[SLICE_LHS_0]], %{{.*}}
767+
// CHECK: %[[SLICE_LHS_1:.+]] = vector.transfer_read %[[ALLOC_LHS]]
768+
// CHECK: %[[VECTOR_LHS_1:.+]] = vector.insert_strided_slice %[[SLICE_LHS_1]], %[[VECTOR_LHS_0]] {{.*}} : vector<1x4xf8E4M3FNUZ> into vector<1x4x1x2x1x4xf8E4M3FNUZ>
769+
770+
// CHECK-COUNT-6: vector.transfer_read %[[ALLOC_RHS]]
771+
// CHECK: %[[SLICE_RHS_0:.+]] = vector.transfer_read %[[ALLOC_RHS]]
772+
// CHECK: %[[VECTOR_RHS_0:.+]] = vector.insert_strided_slice %[[SLICE_RHS_0]], %{{.*}}
773+
// CHECK: %[[SLICE_RHS_1:.+]] = vector.transfer_read %[[ALLOC_RHS]]
774+
// CHECK: %[[VECTOR_RHS_1:.+]] = vector.insert_strided_slice %[[SLICE_RHS_1]], %[[VECTOR_RHS_0]] {{.*}} : vector<4x1xf8E4M3FNUZ> into vector<4x1x2x1x4x1xf8E4M3FNUZ>
775+
776+
// CHECK: %[[EXTRACT_LHS:.+]] = vector.extract %[[VECTOR_LHS_1]][{{.*}}, {{.*}}] : vector<1x2x1x4xf8E4M3FNUZ> from vector<1x4x1x2x1x4xf8E4M3FNUZ>
777+
// CHECK: %[[EXTRACT_RHS:.+]] = vector.extract %[[VECTOR_RHS_1]][{{.*}}, {{.*}}] : vector<2x1x4x1xf8E4M3FNUZ> from vector<4x1x2x1x4x1xf8E4M3FNUZ>
778+
779+
// CHECK: %[[LHS_CAST:.+]] = vector.shape_cast %[[EXTRACT_LHS]] : vector<1x2x1x4xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
780+
// CHECK: %[[RHS_CAST:.+]] = vector.shape_cast %[[EXTRACT_RHS]] : vector<2x1x4x1xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
781+
// CHECK: amdgpu.mfma %[[LHS_CAST]] * %[[RHS_CAST]] + %{{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}
782+
// CHECK-SAME: : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
783+
784+
// Ensure right number of instructions are being generated.
785+
// CHECK-COUNT-3: amdgpu.mfma
786+
787+
// CHECK: scf.yield
788+
789+
// -----
790+
721791
#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
722792
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>
723793

tests/e2e/matmul/generate_e2e_matmul_tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def get_rocm_test_compilation_infos(
350350
MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
351351
MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
352352
MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
353+
MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1),
354+
MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
353355
]
354356
elif intrinsic == "WMMA":
355357
schedules = [
@@ -399,6 +401,7 @@ def get_rocm_test_compilation_infos(
399401
schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
400402
or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
401403
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
404+
or schedule.intrinsic == "VMFMA_F32_16x16x32_F8E4M3FNUZ"
402405
):
403406
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
404407
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16

0 commit comments

Comments
 (0)