From 6ed2feeff360c60f256a3135b96f54fa744c46a7 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 29 Oct 2025 15:34:43 -0400 Subject: [PATCH 1/2] [mlir][amdgpu][rocdl] Allow for graceful wmma conversion failures --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 +++--- .../Conversion/AMDGPUToROCDL/wmma-gfx11.mlir | 4 ++-- .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir | 13 +++++++++++- .../AMDGPUToROCDL/wmma-gfx1250.mlir | 20 ++++++++++++++----- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 1eca43d96fe85..41e333c621eda 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1043,7 +1043,7 @@ wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); } - llvm_unreachable("Unsupported k value"); + return std::nullopt; } /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` @@ -1135,7 +1135,7 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType, return std::nullopt; } - llvm_unreachable("Unsupported k value"); + return std::nullopt; } /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` @@ -1164,7 +1164,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); - llvm_unreachable("unhandled WMMA case"); + return std::nullopt; } namespace { diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir index d1301d0089220..9fcc1473d4a18 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s 
--convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s // CHECK-LABEL: @wmma_to_rocdl func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>, @@ -32,5 +32,5 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> - func.return + return } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir index b897323340402..57883473bbf06 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 \ +// RUN: --split-input-file --verify-diagnostics | FileCheck %s + // CHECK-LABEL: @wmma_to_rocdl func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, %arg2 : vector<8xf32>, %arg3 : vector<4xf32>, @@ -66,3 +68,12 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, func.return } + +// ----- + +func.func @wmma_unsupported_k(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<8xf16>) { + // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}} + // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}} + amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg1 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16> + func.return +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir index bcbdef040ebe3..267ae8bc4f4c0 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir @@ -1,10 +1,11 @@ -// RUN: mlir-opt %s 
--convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 \ +// RUN: --split-input-file --verify-diagnostics | FileCheck %s // CHECK-LABEL: @wmma_k4 func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) { // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1 amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32> - func.return + return } // CHECK-LABEL: @wmma_k32 @@ -22,7 +23,7 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16> - func.return + return } // CHECK-LABEL: @wmma_k64 @@ -55,7 +56,7 @@ func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16> - func.return + return } // CHECK-LABEL: @wmma_k128 @@ -85,5 +86,14 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>, // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16> - func.return + return +} + +// ----- + +func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) { + // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}} + // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}} + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32> + return } From 30d4397511f868913704a2be3569ce8381ad65cb 
Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 29 Oct 2025 15:36:33 -0400 Subject: [PATCH 2/2] Naming --- mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir index 267ae8bc4f4c0..5e77a3add3184 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir @@ -91,7 +91,7 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>, // ----- -func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) { +func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) { // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}} // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}} amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>