
Commit e5579b9

Update description of index attributes
1 parent b1d2aa2 commit e5579b9

File tree

4 files changed: +56 -38 lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 25 additions & 9 deletions
@@ -1248,38 +1248,54 @@ def AMDGPU_ScaledWMMAOp
                   VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
                   ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleBIdx)>,
  Results<(outs ScaledWMMAOutTypes:$destD)> {
+  // TODO: E5M3FNU scales are supported, but there is not yet MLIR support for
+  // this datatype. Once we have support for that, update the scaleA and scaleB
+  // types here.
  let summary = "MLIR wrapper for scaled wmma instructions";
  let description = [{
    The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
    `wmma` instructions. These instructions perform matrix multiplication with
    per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.

-    The scale instructions support two tile sizes:
+    The scale instructions support a block size of 16 or 32 and two tile sizes:
    - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
    - 32x16x128 with f4 format only (output: vector<8xf32>)

    Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
-    (either f8E8M0FNU, or f8E4M3FN). The index attributes (`scaleAIdx`, `scaleBIdx`)
-    select which element from the scale vector to use for scaling. During lowering,
-    these vectors are packed into i32/i64 values for the hardware intrinsics.
+    (either f8E8M0FNU, or f8E4M3FN) that are packed into i32/i64 values during
+    lowering. The index attributes (`scaleAIdx`, `scaleBIdx`) select which register
+    lanes provide scale values:
+    - Block size 32: For tile size 16x16x128, each matrix gets 64 scales stored in half
+      a VGPR, with `scaleAIdx`/`scaleBIdx` selecting lanes 0-15 (index=0) or
+      16-31 (index=1). For a tile size of 32x16x128, matrix A gets 128 scales in
+      a full VGPR (`scaleAIdx` is unused), while matrix B gets 64 scales in
+      half a VGPR.
+
+    - Block size 16: For a tile size of 16x16x128, each matrix gets
+      128 scales stored in half of two VGPRs, with `scaleAIdx`/`scaleBIdx`
+      selecting lanes 0-15 (index=0) or 16-31 (index=1) for each of the VGPRs.
+      For 32x16x128, matrix A gets 256 scales in two VGPRs (`scaleAIdx` is unused),
+      while matrix B gets 128 scales stored in half of two VGPRs.

    Example:
    ```mlir
    // 16x16x128: fp8 inputs
-    %0 = amdgpu.scaled_wmma 16x16x128 (%scaleVecA[0] * %matA) * (%scaleVecB[0] * %matB) + %matC
+    %0 = amdgpu.scaled_wmma 16x16x128 (%scaleVecA * %matA) * (%scaleVecB * %matB) + %matC
+      {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32}
      : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
        vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>

-    // 32x16x128: fp4 inputs
-    %1 = amdgpu.scaled_wmma 32x16x128 (%scaleVecC[1] * %matD) * (%scaleVecD[0] * %matE) + %matF
+    // 32x16x128: fp4 inputs with different scale indices
+    %1 = amdgpu.scaled_wmma 32x16x128 (%scaleVecD * %matD) * (%scaleVecE * %matE) + %matF
+      {scaleAIdx = 0 : i32, scaleBIdx = 1 : i32}
      : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>,
        vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
    ```
  }];
  let assemblyFormat = [{
    custom<MNKDimensionList>($m, $n, $k) ` `
-    `(` $scaleA `[` $scaleAIdx `]` `*` $sourceA `)` `*`
-    `(` $scaleB `[` $scaleBIdx `]` `*` $sourceB `)` `+` $destC
+    `(` $scaleA `*` $sourceA `)` `*`
+    `(` $scaleB `*` $sourceB `)` `+` $destC
    attr-dict
    `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
  }];
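To make the new assembly form and the `scaleAIdx`/`scaleBIdx` attributes concrete, here is a minimal standalone sketch (not part of the commit; the function and value names are hypothetical, and the lane comments assume the block-size-32 mapping described above, since the 4-element scale vector corresponds to that form):

```mlir
// Two 16x16x128 ops sharing one 4-element scale vector. The index attribute
// picks which half of the packed scale register feeds each matrix.
func.func @scaled_wmma_idx_sketch(%matA : vector<64xf8E4M3FN>,
                                  %matB : vector<64xf8E4M3FN>,
                                  %acc : vector<8xf32>,
                                  %scales : vector<4xf8E8M0FNU>)
    -> (vector<8xf32>, vector<8xf32>) {
  // scaleAIdx = scaleBIdx = 0: scales are taken from lanes 0-15.
  %r0 = amdgpu.scaled_wmma 16x16x128 (%scales * %matA) * (%scales * %matB) + %acc
        {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32}
        : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
          vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
  // scaleAIdx = 1: matrix A's scales are taken from lanes 16-31 instead.
  %r1 = amdgpu.scaled_wmma 16x16x128 (%scales * %matA) * (%scales * %matB) + %acc
        {scaleAIdx = 1 : i32, scaleBIdx = 0 : i32}
        : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
          vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
  return %r0, %r1 : vector<8xf32>, vector<8xf32>
}
```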

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 3 additions & 1 deletion
@@ -674,6 +674,8 @@ static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    int64_t numElements = vectorType.getNumElements();
+    assert((numElements == 4 || numElements == 8) &&
+           "scale operand must be a vector of length 4 or 8");
    IntegerType outputType =
        (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
    return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
@@ -691,7 +693,7 @@ static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
}

/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
-/// and scale vector length.
+/// and scale block size (16 or 32).
static std::optional<StringRef>
getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16) {
  if (m == 16 && n == 16 && k == 128)
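As a rough illustration of what the updated comment means (grounded in the CHECK lines of the tests below; function and value names are hypothetical): the length of the scale vector decides both the packed integer type and whether the block-size-32 (`scale`) or block-size-16 (`scale16`) intrinsic is chosen.

```mlir
// Sketch only: a 4-element scale vector is packed into an i32 and lowers to a
// rocdl.wmma.scale.* intrinsic; an 8-element scale vector is packed into an i64
// and lowers to a rocdl.wmma.scale16.* intrinsic.
func.func @scale_packing_sketch(%a : vector<64xf8E4M3FN>, %acc : vector<8xf32>,
                                %s4 : vector<4xf8E8M0FNU>, %s8 : vector<8xf8E8M0FNU>) {
  // Expected to lower to rocdl.wmma.scale.f32.16x16x128.f8f6f4 taking i32 scales.
  %0 = amdgpu.scaled_wmma 16x16x128 (%s4 * %a) * (%s4 * %a) + %acc
       {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32}
       : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
         vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
  // Expected to lower to rocdl.wmma.scale16.f32.16x16x128.f8f6f4 taking i64 scales.
  %1 = amdgpu.scaled_wmma 16x16x128 (%s8 * %a) * (%s8 * %a) + %acc
       {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32}
       : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>,
         vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
  func.return
}
```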

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir

Lines changed: 16 additions & 16 deletions
@@ -93,10 +93,10 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
                                    %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg0) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg0) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>

  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3[1] * %arg1) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg1) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 1 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>

  func.return
}
@@ -105,10 +105,10 @@ func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<
func.func @wmma_scale_16x16x128_fp6(%arg0 : vector<64xf6E2M3FN>, %arg1 : vector<64xf6E3M2FN>,
                                    %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg0) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg0) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>

  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg1) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg1) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>

  func.return
}
@@ -118,10 +118,10 @@ func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vecto
                                      %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<8xf32>,
                                      %arg4 : vector<4xf8E8M0FNU>, %arg5 : vector<4xf8E4M3FN>) {
  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg4[0] * %arg0) * (%arg5[0] * %arg2) + %arg3 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg4 * %arg0) * (%arg5 * %arg2) + %arg3 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>

  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<12xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %1 = amdgpu.scaled_wmma 16x16x128 (%arg4[0] * %arg1) * (%arg5[0] * %arg2) + %arg3 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg4 * %arg1) * (%arg5 * %arg2) + %arg3 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>

  func.return
}
@@ -130,10 +130,10 @@ func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vecto
func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E3M2FN>,
                                      %arg2 : vector<8xf32>, %arg3 : vector<8xf8E8M0FNU>) {
  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg0) + %arg2 : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg0) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>

  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
-  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3[1] * %arg1) * (%arg3[0] * %arg1) + %arg2 : vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg1) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 1 : i32, scaleBIdx = 0 : i32} : vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>

  func.return
}
@@ -142,7 +142,7 @@ func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vecto
func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
                                    %arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
  // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
-  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>

  func.return
}
@@ -151,7 +151,7 @@ func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector
func.func @wmma_scale16_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
                                      %arg2 : vector<16xf32>, %arg3 : vector<8xf8E4M3FN>) {
  // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
-  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>, vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>, vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>

  func.return
}
@@ -170,42 +170,42 @@ func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
func.func @scaled_wmma_wrong_output_length(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<16xf32>,
                                           %arg2 : vector<4xf8E8M0FNU>) {
  // expected-error@below {{'amdgpu.scaled_wmma' op expected output vector of length 8 but got 16}}
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg2[0] * %arg0) * (%arg2[0] * %arg0) + %arg1 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<16xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg2 * %arg0) * (%arg2 * %arg0) + %arg1 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<16xf32>
  return
}

func.func @scaled_wmma_16x16_wrong_sourceA_length(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
                                                  %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
  // expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceA must have 64 elements but got 128}}
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf4E2M1FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf4E2M1FN>, vector<8xf32>
  return
}

func.func @scaled_wmma_16x16_wrong_sourceB_length(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<128xf4E2M1FN>,
                                                  %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
  // expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceB must have 64 elements but got 128}}
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<8xf32>
  return
}

func.func @scaled_wmma_32x16_wrong_sourceA_length(%arg0 : vector<64xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
                                                  %arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
  // expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceA must have 128 elements but got 64}}
-  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
  return
}

func.func @scaled_wmma_32x16_wrong_sourceB_length(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<128xf4E2M1FN>,
                                                  %arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
  // expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceB must have 64 elements but got 128}}
-  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<16xf32>
  return
}

func.func @scaled_wmma_invalid_type_combination(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
                                                %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>,
                                                %arg4 : vector<4xf8E4M3FN>) {
  // expected-error@below {{'amdgpu.scaled_wmma' op invalid combination of matrix and scale types}}
-  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg4[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf6E2M3FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2 {scaleAIdx = 0 : i32, scaleBIdx = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf6E2M3FN>, vector<8xf32>
  return
}
