Skip to content

Commit 5734253

Browse files
kuharaokblast
authored andcommitted
[mlir][amdgpu][rocdl] Allow for graceful wmma conversion failures (llvm#165616)
1 parent bc5b3a9 commit 5734253

File tree

4 files changed

+32
-11
lines changed

4 files changed

+32
-11
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
10431043
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
10441044
}
10451045

1046-
llvm_unreachable("Unsupported k value");
1046+
return std::nullopt;
10471047
}
10481048

10491049
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -1135,7 +1135,7 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
11351135
return std::nullopt;
11361136
}
11371137

1138-
llvm_unreachable("Unsupported k value");
1138+
return std::nullopt;
11391139
}
11401140

11411141
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -1164,7 +1164,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
11641164
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
11651165
elemDestType, k);
11661166

1167-
llvm_unreachable("unhandled WMMA case");
1167+
return std::nullopt;
11681168
}
11691169

11701170
namespace {

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s
22

33
// CHECK-LABEL: @wmma_to_rocdl
44
func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
@@ -32,5 +32,5 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
3232
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
3333
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
3434

35-
func.return
35+
return
3636
}

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 \
2+
// RUN: --split-input-file --verify-diagnostics | FileCheck %s
3+
24
// CHECK-LABEL: @wmma_to_rocdl
35
func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
46
%arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
@@ -66,3 +68,12 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
6668

6769
func.return
6870
}
71+
72+
// -----
73+
74+
func.func @wmma_unsupported_k(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<8xf16>) {
75+
// expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
76+
// expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
77+
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg1 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
78+
func.return
79+
}

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 \
2+
// RUN: --split-input-file --verify-diagnostics | FileCheck %s
23

34
// CHECK-LABEL: @wmma_k4
45
func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
56
// CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
67
amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
7-
func.return
8+
return
89
}
910

1011
// CHECK-LABEL: @wmma_k32
@@ -22,7 +23,7 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
2223
// CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
2324
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
2425

25-
func.return
26+
return
2627
}
2728

2829
// CHECK-LABEL: @wmma_k64
@@ -55,7 +56,7 @@ func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 :
5556
// CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
5657
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
5758

58-
func.return
59+
return
5960
}
6061

6162
// CHECK-LABEL: @wmma_k128
@@ -85,5 +86,14 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
8586
// CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
8687
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
8788

88-
func.return
89+
return
90+
}
91+
92+
// -----
93+
94+
func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
95+
// expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
96+
// expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
97+
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>
98+
return
8999
}

0 commit comments

Comments
 (0)