Skip to content

Commit e02ca45

Browse files
kuharLukacma
authored andcommitted
[mlir][amdgpu] Update mfma assembly format with intrinsic shape (llvm#165037)
Use the same format as introduced for wmma by llvm#164920. Also make `blocks` default to 1.
1 parent b0e24bd commit e02ca45

File tree

6 files changed

+131
-110
lines changed

6 files changed

+131
-110
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -923,10 +923,10 @@ def AMDGPU_MFMAOp :
923923
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
924924
Pure]>,
925925
Arguments<(ins
926-
I32Attr:$m,
927-
I32Attr:$n,
928-
I32Attr:$k,
929-
I32Attr:$blocks,
926+
ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32]>]>:$m,
927+
ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32]>]>:$n,
928+
ConfinedAttr<I32Attr, [IntIsOneOf<[1, 2, 4, 8, 16, 32, 64, 128]>]>:$k,
929+
DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[1, 2, 4, 16]>]>, "1">:$blocks,
930930
MFMAInTypes:$sourceA,
931931
MFMAInTypes:$sourceB,
932932
MFMAOutTypes:$destC,
@@ -969,14 +969,16 @@ def AMDGPU_MFMAOp :
969969

970970
Example:
971971
```mlir
972-
%0 = amdgpu.mfma %matA * %matB + %matC
973-
{ abid = 1 : i32, cbsz = 1 : i32,
974-
m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
972+
%0 = amdgpu.mfma 16x16x16 %matA * %matB + %matC
973+
: vector<4xf16>, vector<4xf16>, vector<4xf32>
974+
975+
%1 = amdgpu.mfma 32x32x1 %matD * %matE + %matF
976+
{ abid = 1 : i32, cbsz = 1 : i32, blocks = 2 : i32 }
975977
blgp = bcast_second_32 : f32, f32, vector<32xf32>
976978
```
977979
}];
978980
let assemblyFormat = [{
979-
$sourceA `*` $sourceB `+` $destC
981+
custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
980982
attr-dict
981983
`blgp` `=` $blgp
982984
`:` type($sourceA) `,` type($sourceB) `,` type($destC)

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,11 @@ LogicalResult MFMAOp::verify() {
422422

423423
Type sourceElem = sourceType, destElem = destType;
424424
uint32_t sourceLen = 1, destLen = 1;
425-
if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
425+
if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
426426
sourceLen = sourceVector.getNumElements();
427427
sourceElem = sourceVector.getElementType();
428428
}
429-
if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
429+
if (auto destVector = dyn_cast<VectorType>(destType)) {
430430
destLen = destVector.getNumElements();
431431
destElem = destVector.getElementType();
432432
}
@@ -451,7 +451,7 @@ LogicalResult MFMAOp::verify() {
451451
return emitOpError("expected both non-small-float source operand types "
452452
"to match exactly");
453453
}
454-
// Normalize the wider integer types the compiler expects to i8
454+
// Normalize the wider integer types the compiler expects to i8.
455455
if (sourceElem.isInteger(32)) {
456456
sourceLen *= 4;
457457
sourceElem = b.getI8Type();

mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,46 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
88
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
99

1010
// CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
11-
amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
11+
amdgpu.mfma 32x32x16 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
1212
// CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
13-
amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
13+
amdgpu.mfma 16x16x32 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
1414
// CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
15-
amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
15+
amdgpu.mfma 32x32x16 %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
1616
// CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
17-
amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
17+
amdgpu.mfma 16x16x32 %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
1818
// CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
19-
amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
19+
amdgpu.mfma 32x32x32 %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
2020
// CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
21-
amdgpu.mfma %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
21+
amdgpu.mfma 16x16x64 %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
2222
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
23-
amdgpu.mfma %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
23+
amdgpu.mfma 32x32x64 %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
2424
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
25-
amdgpu.mfma %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
25+
amdgpu.mfma 16x16x128 %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
2626
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
2727
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
28-
amdgpu.mfma %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
28+
amdgpu.mfma 32x32x64 %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
2929
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
30-
amdgpu.mfma %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
30+
amdgpu.mfma 16x16x128 %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
3131
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
3232
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
33-
amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
33+
amdgpu.mfma 32x32x64 %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
3434
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
35-
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
35+
amdgpu.mfma 16x16x128 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
3636
// CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
3737
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
38-
amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
38+
amdgpu.mfma 32x32x64 %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
3939
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
40-
amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
40+
amdgpu.mfma 16x16x128 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
4141
// CHECK-DAG: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
4242
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
43-
amdgpu.mfma %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
43+
amdgpu.mfma 32x32x64 %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
4444
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
45-
amdgpu.mfma %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
45+
amdgpu.mfma 16x16x128 %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
4646

4747
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
48-
amdgpu.mfma %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32>
48+
amdgpu.mfma 32x32x64 %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32>
4949
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
50-
amdgpu.mfma %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32>
50+
amdgpu.mfma 16x16x128 %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32>
5151

5252
func.return
5353
}
@@ -57,9 +57,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5757
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
5858
%arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
5959
%arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
60-
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
60+
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
6161
%arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
62-
62+
6363
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
6464
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
6565
// CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
@@ -69,32 +69,32 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
6969
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
7070
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
7171
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
72-
72+
7373
// CHECK: llvm.bitcast
74-
74+
7575
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
7676
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
7777
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
7878
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
79-
79+
8080
// CHECK: llvm.bitcast
81-
81+
8282
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
8383
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
8484
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
8585
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
86-
86+
8787
// CHECK: llvm.bitcast
8888
// CHECK: llvm.mlir.constant(3 : i32) : i32
8989

9090
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
9191
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
9292
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
9393
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
94-
94+
9595
// CHECK: llvm.bitcast
9696
// CHECK: llvm.mlir.constant(4 : i32) : i32
97-
97+
9898
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
9999
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
100100
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>

0 commit comments

Comments
 (0)