
Commit 72ae51d

kuharaokblast authored and committed

[mlir][amdgpu] Add explicit intrinsic shape to wmma (llvm#164920)

This is in preparation for adding support for gfx1250 wmma intrinsics, which include many more possible shapes. Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.

1 parent 4c72b09

File tree: 9 files changed, +225 −118 lines

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 30 additions & 17 deletions
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                                [4, 8, 16],
-                                [F16, BF16,
-                                 I8, SI8, UI8,
-                                 I<4>, SI<4>, UI<4>,
-                                 F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;

@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :

     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+    %0 = amdgpu.mfma %matA * %matB + %matC
+      { abid = 1 : i32, cbsz = 1 : i32,
+        m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+      blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,36 +988,43 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
-                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>, "0">:$subwordOffset,
                    UnitAttr:$unsignedA,
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary in allowed K
+    dimensions.

     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
     - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
     - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as a vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.

     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.

-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+    %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h

Lines changed: 24 additions & 1 deletion
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//

@@ -26,6 +26,29 @@

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"

+namespace mlir::amdgpu {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir::amdgpu
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 5 additions & 0 deletions
@@ -804,6 +804,11 @@ def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;

 class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;

+class IntIsOneOf<list<int> values> : AttrConstraint<
+  CPred<"::llvm::is_contained({" # !interleave(!foreach(val, values, val), ", ") #
+        "}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
+  "whose value is one of {" # !interleave(!foreach(val, values, val), ", ") # "}">;
+
 class ArrayMaxCount<int n> : AttrConstraint<
   CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
   "with at most " # n # " elements">;

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 40 additions & 30 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }

 namespace {

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 35 additions & 27 deletions
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();

-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+                                                IntegerAttr &m, IntegerAttr &n,
+                                                IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";

-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}

-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }

-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();

-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");

-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }

Lines changed: 14 additions & 13 deletions
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>

   func.return
 }
