Skip to content

Commit 4ad834b

Browse files
authored
Support F8E5M2FNUZ MFMA on CDNA3 (#18887)
F8E4M3FNUZ was already there. --------- Signed-off-by: Benoit Jacob <[email protected]>
1 parent 2291b38 commit 4ad834b

File tree

4 files changed

+44
-13
lines changed

4 files changed

+44
-13
lines changed

compiler/plugins/target/ROCM/test/target_device_features.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// GFX942: target = #iree_gpu.target<arch = "gfx942",
1616
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
1717
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
18-
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
18+
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
1919
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
2020
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
2121
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,7 +26,7 @@
2626
// GFX941-SAME: features = "+sramecc,-xnack"
2727

2828
// GFX940: target = #iree_gpu.target<arch = "gfx940",
29-
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
29+
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
3030

3131
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
3232
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {
212212
static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
213213
MMAIntrinsic type) {
214214
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
215+
Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context);
215216
Type f16 = Float16Type::get(context);
216217
Type f32 = Float32Type::get(context);
217218

@@ -231,6 +232,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
231232
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
232233
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
233234
}
235+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
236+
return OpaqueMmaLayout{16, 16, 32, f8E5M2FNUZ, f8E5M2FNUZ, f32};
237+
}
234238
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
235239
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
236240
}
@@ -472,6 +476,7 @@ MMAAttr::getABCVectorTypes() const {
472476
return std::make_tuple(aType, bType, cType);
473477
}
474478
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
479+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
475480
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
476481
auto aType = VectorType::get({8}, getAType());
477482
auto bType = VectorType::get({8}, getBType());
@@ -518,6 +523,7 @@ int64_t MMAAttr::getBlockSize() const {
518523
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
519524
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
520525
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
526+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
521527
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
522528
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
523529
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
@@ -538,6 +544,7 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
538544
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
539545
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
540546
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
547+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
541548
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
542549
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
543550
return 64;
@@ -602,6 +609,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
602609
/*element=*/{4, 1}};
603610
}
604611
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
612+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
605613
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
606614
switch (fragment) {
607615
case MMAFragment::Lhs:
@@ -699,6 +707,7 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
699707
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
700708
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
701709
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
710+
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
702711
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
703712
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
704713
auto [m, n, k] = getMNKShape();

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,31 +99,52 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
9999
}
100100

101101
// Format: <kind>_<output-type>_<M>x<N>x<K>_<input-type>
102-
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
103-
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 1>;
104-
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 2>;
105-
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 3>;
106-
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 4>;
107-
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>;
102+
// Values: 0xABCD where:
103+
// * A = vendor:
104+
// * 0 = AMD
105+
// * 1 = NVIDIA
106+
// * B is architecture:
107+
// * For AMD:
108+
// * 0 = RDNA3
109+
// * 8 = CDNA2
110+
// * 9 = CDNA3
111+
// * C is A/B data type:
112+
// * 0 = f32
113+
// * 1 = f16
114+
// * 2 = bf16
115+
// * 3 = f8e5m2 (and variants like fnuz).
116+
// * 4 = f8e4m3 (and variants like fnuz).
117+
// * 8 = i8
118+
// * D enumerates intrinsics for the same data type.
119+
//
120+
// CDNA3 instrinsics
121+
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
122+
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
123+
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
124+
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
125+
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
126+
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
127+
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x0981>;
108128

109129
// CDNA2 instrinsics
110-
def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 6>;
111-
def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 7>;
130+
def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 0x0880>;
131+
def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 0x0881>;
112132

113133
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
114-
def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 8>;
115-
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 9>;
134+
def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 0x0010>;
135+
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 0x0011>;
116136

117137
// TODO: The actual I8 instruction allows specifying (mixed) signedness.
118138
// This will need to become its own class of MMA attribute.
119-
def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 10>;
139+
def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 0x0080>;
120140

121141
def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
122142
"Descriptor for different MMA intrinsics", [
123143
MFMA_F32_16x16x4_F32,
124144
MFMA_F32_16x16x16_F16,
125145
MFMA_F32_32x32x8_F16,
126146
MFMA_F32_16x16x32_F8E4M3FNUZ,
147+
MFMA_F32_16x16x32_F8E5M2FNUZ,
127148
MFMA_I32_16x16x32_I8,
128149
MFMA_I32_32x32x16_I8,
129150
MFMA_I32_16x16x16_I8,

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ const WgpDetails *getCDNA3WgpDetails() {
137137
MMAIntrinsic::MFMA_F32_16x16x16_F16,
138138
MMAIntrinsic::MFMA_F32_32x32x8_F16,
139139
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
140+
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
140141
MMAIntrinsic::MFMA_I32_16x16x32_I8,
141142
MMAIntrinsic::MFMA_I32_32x32x16_I8,
142143
};

0 commit comments

Comments
 (0)