Skip to content

Commit 77eaa97

Browse files
Addressing review comments
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent b768b8a commit 77eaa97

File tree

4 files changed

+16
-16
lines changed

4 files changed

+16
-16
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13001300
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
13011301
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
13021302

1303-
bool isGFX1250 = chipset == Chipset(12, 5, 0);
1303+
bool isGFX1250 = chipset >= Chipset(12, 5, 0);
13041304

13051305
// The WMMA operations represent vectors of bf16s as vectors of i16s
13061306
// (except on gfx1250), so we need to bitcast bfloats to i16 and then

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
66
%arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
77
%arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
88
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
9-
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
9+
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<8xf32>
1010
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
11-
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
11+
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<4xf32>
1212
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
13-
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
13+
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
1414
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
15-
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
15+
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
1616
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
1717
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
1818
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
@@ -23,13 +23,13 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
2323
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
2424
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
2525
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
26-
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
26+
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
2727
amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
28-
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
28+
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
2929
amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
30-
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
30+
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
3131
amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
32-
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
32+
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
3333
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
3434

3535
return

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,19 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
5151
// CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
5252
amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
5353

54-
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
54+
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
5555
amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
56-
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
56+
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
5757
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
5858

59-
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
59+
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
6060
amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
61-
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
61+
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
6262
amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
6363

64-
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i32, i32, vector<8xi32>) -> vector<8xi32>
64+
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<8xi32>) -> vector<8xi32>
6565
amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
66-
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
66+
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
6767
amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
6868

6969
func.return

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
2929
// CHECK-LABEL: @wmma_k64
3030
func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
3131
%arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
32-
// CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {{.*}}
32+
// CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {clamp = true, signA = true, signB = true}
3333
amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
3434

3535
// CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4

0 commit comments

Comments
 (0)