Skip to content

Commit b08ea12

Browse files
authored
[LLVMGPU] Add 32x32x16 F8 MFMA intrinsic (iree-org#19106)
To enable faster SDXL on attention we'd need different FP8 MFMA intrinsics. This 32x32x16 FP8 intrinsic (and virtual intrinsic for 2nd matmul) has been especially performant when used on this SDXL attention shape (B0: 2, B1: 10, (M, K2): 4096: K1: 64). --------- Signed-off-by: Stanley Winata <[email protected]>
1 parent 11fe5cd commit b08ea12

File tree

6 files changed

+169
-16
lines changed

6 files changed

+169
-16
lines changed

compiler/plugins/target/ROCM/test/target_device_features.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// GFX942: target = #iree_gpu.target<arch = "gfx942",
1616
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
1717
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
18-
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
18+
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
1919
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
2020
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
2121
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,7 +26,7 @@
2626
// GFX941-SAME: features = "+sramecc,-xnack"
2727

2828
// GFX940: target = #iree_gpu.target<arch = "gfx940",
29-
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
29+
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
3030

3131
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
3232
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>]

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,18 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
256256
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
257257
return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
258258
}
259+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
260+
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
261+
}
262+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
263+
return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
264+
}
265+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
266+
return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
267+
}
268+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
269+
return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
270+
}
259271
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
260272
return {i8, i8, i32};
261273
}
@@ -608,6 +620,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
608620
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
609621
/*element=*/{4, 1}};
610622
}
623+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
624+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
625+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
626+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
611627
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
612628
switch (fragment) {
613629
case MMAFragment::Lhs:
@@ -675,6 +691,8 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
675691
return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16};
676692
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
677693
return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ};
694+
case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
695+
return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ};
678696
default:
679697
return {};
680698
}
@@ -1218,6 +1236,9 @@ static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context,
12181236
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
12191237
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
12201238
}
1239+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
1240+
return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32};
1241+
}
12211242
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
12221243
// along the k dimension.
12231244
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
@@ -1252,6 +1273,7 @@ VirtualMMAAttr::getABCVectorTypes() const {
12521273
auto cType = VectorType::get({4}, C);
12531274
return {aType, bType, cType};
12541275
}
1276+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12551277
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
12561278
auto aType = VectorType::get({8}, A);
12571279
auto bType = VectorType::get({8}, B);
@@ -1274,6 +1296,7 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
12741296
switch (getIntrinsic().getValue()) {
12751297
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
12761298
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1299+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12771300
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
12781301
return 64;
12791302
}
@@ -1328,7 +1351,8 @@ int64_t VirtualMMAAttr::getUnrollK() const {
13281351
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
13291352
return 2;
13301353
}
1331-
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
1354+
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1355+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
13321356
return 1;
13331357
}
13341358
}
@@ -1356,6 +1380,7 @@ FailureOr<Value> VirtualMMAAttr::buildMmaOperation(OpBuilder &builder,
13561380
switch (getIntrinsic().getValue()) {
13571381
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
13581382
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1383+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
13591384
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
13601385
// Generate mfma's for K with unrolled kernels.
13611386
const int64_t unrollKFactor = getUnrollK();
@@ -1394,6 +1419,7 @@ int64_t VirtualMMAAttr::getBlockSize() const {
13941419
switch (getIntrinsic().getValue()) {
13951420
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
13961421
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1422+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
13971423
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
13981424
return 1;
13991425
}
@@ -1442,6 +1468,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
14421468
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
14431469
/*element=*/{4, 1}};
14441470
}
1471+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1472+
switch (fragment) {
1473+
case MMAFragment::Lhs:
1474+
return {/*outer=*/{1, 2}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
1475+
/*element=*/{1, 4}};
1476+
case MMAFragment::Rhs:
1477+
return {/*outer=*/{2, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
1478+
/*element=*/{4, 1}};
1479+
case MMAFragment::Acc:
1480+
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
1481+
/*element=*/{4, 1}};
1482+
}
14451483
}
14461484
assert(false && "unhandled virtual mma layout type.");
14471485
return {};

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ
158158
def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>;
159159
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>;
160160
def MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ", 0x1233>;
161+
def MFMA_F32_32x32x16_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E5M2FNUZ", 0x1234>;
162+
def MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ", 0x1235>;
163+
def MFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E4M3FNUZ", 0x1236>;
164+
def MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ", 0x1237>;
161165
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x12C0>;
162166
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x12C1>;
163167

@@ -193,6 +197,10 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
193197
MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
194198
MFMA_F32_16x16x32_F8E4M3FNUZ,
195199
MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
200+
MFMA_F32_32x32x16_F8E5M2FNUZ,
201+
MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ,
202+
MFMA_F32_32x32x16_F8E4M3FNUZ,
203+
MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ,
196204
MFMA_I32_16x16x32_I8,
197205
MFMA_I32_32x32x16_I8,
198206

@@ -211,12 +219,14 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
211219
def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0>;
212220
def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 1>;
213221
def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 2>;
222+
def VMFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_32x32x16_F8E4M3FNUZ", 3>;
214223

215224
def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32MmaEnumAttr<"VirtualMMAIntrinsic",
216225
"Descriptor for different Virtual MMA intrinsics", [
217226
VMFMA_F32_16x16x32_F16,
218227
VMFMA_F32_32x32x16_F16,
219228
VMFMA_F32_16x16x32_F8E4M3FNUZ,
229+
VMFMA_F32_32x32x16_F8E4M3FNUZ,
220230
]>;
221231

222232
def MMA_LHS : I32EnumAttrCase<"Lhs", 0>;

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ const WgpDetails *getCDNA3WgpDetails() {
146146
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
147147
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
148148
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
149+
MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ,
150+
MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ,
151+
MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ,
152+
MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ,
149153
MMAIntrinsic::MFMA_I32_16x16x32_I8,
150154
MMAIntrinsic::MFMA_I32_32x32x16_I8,
151155
};

0 commit comments

Comments
 (0)