Skip to content

Commit 20d9145

Browse files
kuhar authored and aokblast committed
[mlir][amdgpu][rocdl] Add gfx1250 wmma ops (llvm#165064)
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.
1 parent ff7bed9 commit 20d9145

File tree

6 files changed

+346
-50
lines changed

6 files changed

+346
-50
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
912912
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
913913
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
914914
// wmma
915-
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
916-
VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
917-
VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
915+
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
916+
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
917+
VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
918+
VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
918919
VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
919920
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
920921
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -992,7 +993,7 @@ def AMDGPU_WMMAOp :
992993
Arguments<(ins
993994
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
994995
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
995-
ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
996+
ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
996997
WMMAInTypes:$sourceA,
997998
WMMAInTypes:$sourceB,
998999
WMMAOutTypes:$destC,
@@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp :
10051006
let description = [{
10061007
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
10071008
instructions in the AMDGPU architecture, which perform matrix multiplication.
1008-
Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
1009-
dimensions.
1009+
1010+
On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
1011+
1012+
On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
1013+
all element types, and K=32 for i4 sources.
1014+
1015+
On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
1016+
depending on the element types.
10101017

10111018
On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
10121019
(or 16xbf16) vector containing only 8 valid values:
@@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp :
10221029

10231030
Example:
10241031
```mlir
1025-
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
1032+
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
1033+
1034+
%1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32>
1035+
1036+
%2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
1037+
1038+
%3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
10261039
```
10271040
}];
10281041
let assemblyFormat = [{

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 146 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
989989
smfma.getN(), smfma.getK(), 1u, chipset);
990990
}
991991

992-
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993-
/// if one exists. This includes checking to ensure the intrinsic is supported
994-
/// on the architecture you are compiling for.
995-
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
996-
Chipset chipset) {
997-
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
998-
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
999-
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1000-
Type elemSourceType = sourceVectorType.getElementType();
1001-
Type elemBSourceType = sourceBVectorType.getElementType();
1002-
Type elemDestType = destVectorType.getElementType();
1003-
1004-
const uint32_t k = wmma.getK();
1005-
992+
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993+
/// for RDNA3/4 architectures.
994+
static std::optional<StringRef>
995+
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
996+
Type elemDestType, uint32_t k, bool isRDNA3) {
997+
using fp8 = Float8E4M3FNType;
998+
using bf8 = Float8E5M2Type;
999+
1000+
// Handle k == 16 for RDNA3/4.
10061001
if (k == 16) {
1002+
// Common patterns for RDNA3 and RDNA4.
10071003
if (elemSourceType.isF16() && elemDestType.isF32())
10081004
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
10091005
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10141010
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
10151011
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
10161012
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1017-
if (chipset.majorVersion == 11) {
1013+
1014+
// RDNA3 specific patterns.
1015+
if (isRDNA3) {
10181016
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
10191017
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1018+
return std::nullopt;
10201019
}
1021-
}
1022-
if (chipset.majorVersion < 12)
1023-
return std::nullopt;
10241020

1025-
// gfx12+
1026-
if (k == 16) {
1027-
if (isa<Float8E4M3FNType>(elemSourceType) &&
1028-
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1021+
// RDNA4 specific patterns (fp8/bf8).
1022+
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1023+
elemDestType.isF32())
10291024
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1030-
if (isa<Float8E4M3FNType>(elemSourceType) &&
1031-
isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1025+
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1026+
elemDestType.isF32())
10321027
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1033-
if (isa<Float8E5M2Type>(elemSourceType) &&
1034-
isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1028+
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1029+
elemDestType.isF32())
10351030
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1036-
if (isa<Float8E5M2Type>(elemSourceType) &&
1037-
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1031+
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1032+
elemDestType.isF32())
10381033
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
10391034
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
10401035
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
10411036

10421037
return std::nullopt;
10431038
}
1044-
if (k == 32) {
1039+
1040+
// Handle k == 32 for RDNA4.
1041+
if (k == 32 && !isRDNA3) {
10451042
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
10461043
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1044+
}
1045+
1046+
llvm_unreachable("Unsupported k value");
1047+
}
1048+
1049+
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1050+
/// for the gfx1250 architecture.
1051+
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1052+
Type elemBSourceType,
1053+
Type elemDestType,
1054+
uint32_t k) {
1055+
using fp8 = Float8E4M3FNType;
1056+
using bf8 = Float8E5M2Type;
1057+
1058+
if (k == 4) {
1059+
if (elemSourceType.isF32() && elemDestType.isF32())
1060+
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1061+
10471062
return std::nullopt;
10481063
}
10491064

1065+
if (k == 32) {
1066+
if (elemSourceType.isF16() && elemDestType.isF32())
1067+
return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1068+
if (elemSourceType.isBF16() && elemDestType.isF32())
1069+
return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1070+
if (elemSourceType.isF16() && elemDestType.isF16())
1071+
return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1072+
if (elemSourceType.isBF16() && elemDestType.isBF16())
1073+
return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1074+
1075+
return std::nullopt;
1076+
}
1077+
1078+
if (k == 64) {
1079+
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1080+
if (elemDestType.isF32())
1081+
return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1082+
if (elemDestType.isF16())
1083+
return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1084+
}
1085+
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1086+
if (elemDestType.isF32())
1087+
return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1088+
if (elemDestType.isF16())
1089+
return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1090+
}
1091+
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1092+
if (elemDestType.isF32())
1093+
return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1094+
if (elemDestType.isF16())
1095+
return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1096+
}
1097+
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1098+
if (elemDestType.isF32())
1099+
return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1100+
if (elemDestType.isF16())
1101+
return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1102+
}
1103+
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1104+
return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1105+
1106+
return std::nullopt;
1107+
}
1108+
1109+
if (k == 128) {
1110+
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1111+
if (elemDestType.isF32())
1112+
return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1113+
if (elemDestType.isF16())
1114+
return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1115+
}
1116+
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1117+
if (elemDestType.isF32())
1118+
return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1119+
if (elemDestType.isF16())
1120+
return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1121+
}
1122+
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1123+
if (elemDestType.isF32())
1124+
return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1125+
if (elemDestType.isF16())
1126+
return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1127+
}
1128+
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1129+
if (elemDestType.isF32())
1130+
return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1131+
if (elemDestType.isF16())
1132+
return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1133+
}
1134+
1135+
return std::nullopt;
1136+
}
1137+
1138+
llvm_unreachable("Unsupported k value");
1139+
}
1140+
1141+
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1142+
/// if one exists. This includes checking to ensure the intrinsic is supported
1143+
/// on the architecture you are compiling for.
1144+
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1145+
Chipset chipset) {
1146+
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1147+
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1148+
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1149+
Type elemSourceType = sourceVectorType.getElementType();
1150+
Type elemBSourceType = sourceBVectorType.getElementType();
1151+
Type elemDestType = destVectorType.getElementType();
1152+
1153+
const uint32_t k = wmma.getK();
1154+
const bool isRDNA3 = chipset.majorVersion == 11;
1155+
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1156+
1157+
// Handle RDNA3 and RDNA4.
1158+
if (isRDNA3 || isRDNA4)
1159+
return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1160+
k, isRDNA3);
1161+
1162+
// Handle gfx1250.
1163+
if (chipset == Chipset{12, 5, 0})
1164+
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1165+
elemDestType, k);
1166+
10501167
llvm_unreachable("unhandled WMMA case");
10511168
}
10521169

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
399399

400400
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
401401
return emitOpError(
402-
"source element types much match (except for fp8) but have ")
402+
"source element types must match (except for fp8/bf8) but have ")
403403
<< sourceAType << " and " << sourceBType;
404404
}
405405

406-
if (!sourceAElemType.isInteger(4) && getK() != 16) {
407-
return emitOpError("K dimension must be 16 for source element type ")
408-
<< sourceAElemType;
406+
if (isSrcFloat) {
407+
if (getClamp())
408+
return emitOpError("clamp flag is not supported for float types");
409+
if (getUnsignedA() || getUnsignedB())
410+
return emitOpError("unsigned flags are not supported for float types");
409411
}
410412
return success();
411413
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
2+
3+
// CHECK-LABEL: @wmma_k4
4+
func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
5+
// CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
6+
amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
7+
func.return
8+
}
9+
10+
// CHECK-LABEL: @wmma_k32
11+
func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
12+
%arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
13+
// CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
14+
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
15+
16+
// CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
17+
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
18+
19+
// CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
20+
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
21+
22+
// CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
23+
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
24+
25+
func.return
26+
}
27+
28+
// CHECK-LABEL: @wmma_k64
29+
func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
30+
%arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
31+
// CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
32+
amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
33+
34+
// CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
35+
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
36+
37+
// CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
38+
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
39+
40+
// CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
41+
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
42+
43+
// CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
44+
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
45+
46+
// CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
47+
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
48+
49+
// CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
50+
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
51+
52+
// CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
53+
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
54+
55+
// CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
56+
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
57+
58+
func.return
59+
}
60+
61+
// CHECK-LABEL: @wmma_k128
62+
func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
63+
%arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
64+
// CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
65+
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
66+
67+
// CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
68+
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
69+
70+
// CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
71+
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
72+
73+
// CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
74+
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
75+
76+
// CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
77+
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
78+
79+
// CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
80+
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
81+
82+
// CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
83+
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
84+
85+
// CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
86+
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
87+
88+
func.return
89+
}

0 commit comments

Comments
 (0)