From f566fa78261aab88fcb77f06437b53119e52fb3f Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Thu, 30 Oct 2025 08:27:40 +0000 Subject: [PATCH 1/8] [MLIR][NVVM] Update Op verifiers to prevent ungraceful exits Updates the following Ops to prevent ungraceful exits with a stack-dump in certain cases of incorrect usages, and instead gracefully error out with a more informative error message: - tcgen05.ld - shfl.sync --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 24 ++++++++++++++------- mlir/test/Dialect/LLVMIR/invalid.mlir | 7 ++++++ mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8 +++++++ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4dbcc1d4b..402c90fba0f2d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -867,15 +867,20 @@ LogicalResult MmaOp::verify() { } LogicalResult ShflOp::verify() { - if (!(*this)->getAttrOfType("return_value_and_is_valid")) - return success(); auto type = llvm::dyn_cast(getType()); - auto elementType = (type && type.getBody().size() == 2) - ? llvm::dyn_cast(type.getBody()[1]) - : nullptr; - if (!elementType || elementType.getWidth() != 1) - return emitError("expected return type to be a two-element struct with " - "i1 as the second element"); + + if ((*this)->getAttrOfType("return_value_and_is_valid")) { + auto elementType = (type && type.getBody().size() == 2) + ? llvm::dyn_cast(type.getBody()[1]) + : nullptr; + if (!elementType || elementType.getWidth() != 1) + return emitOpError("expected return type to be a two-element struct with " + "i1 as the second element"); + } else { + if (type) + return emitOpError("\"return_value_and_is_valid\" attribute must be " + "specified when returning the predicate"); + } return success(); } @@ -2450,6 +2455,9 @@ LogicalResult Tcgen05LdOp::verify() { LogicalResult result = success(); if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset()) result = emitError("shape 16x32bx2 requires offset argument"); + + if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset()) + result = emitError("offset argument is only supported for shape 16x32bx2"); auto resTy = getRes().getType(); unsigned resLen = isa(resTy) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index aaf9f8024bfbe..90208aa55bd55 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -684,6 +684,13 @@ func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 // ----- +func.func @nvvm_invalid_shfl_pred_4(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)> +} + +// ----- + func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 09b8f593154b5..8cb7b068498fd 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -621,3 +621,11 @@ func.func @invalid_range_equal_bounds() { %0 = nvvm.read.ptx.sreg.warpsize range : i32 return } + +// ----- + +llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () { + // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}} + %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape} : vector<2 x i32> + llvm.return +} From 045ce6c99786cb8634ca885dfca5b94a58ce2b8a Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Thu, 30 Oct 2025 08:37:55 +0000 Subject: [PATCH 2/8] fix formatting --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 402c90fba0f2d..a23245f92cee7 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -2455,7 +2455,7 @@ LogicalResult Tcgen05LdOp::verify() { LogicalResult result = success(); if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset()) result = emitError("shape 16x32bx2 requires offset argument"); - + if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset()) result = emitError("offset argument is only supported for shape 16x32bx2"); From 70fe4cc7f10f3f4c4f1787d838676bb2b71c0bcd Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 31 Oct 2025 06:00:23 +0000 Subject: [PATCH 3/8] address comments --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 8 ++++++-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 ++-- mlir/test/Dialect/LLVMIR/invalid.mlir | 7 ------- mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir | 8 ++++++++ mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir | 9 +++++++++ mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8 -------- 6 files changed, 25 insertions(+), 19 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir create mode 100644 mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 4f483859ac18d..1e915e3027d58 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1341,9 +1341,9 @@ def ShflKindAttr : EnumAttr; def NVVM_ShflOp : NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>, - Results<(outs LLVM_Type:$res)>, + Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>, Arguments<(ins I32:$thread_mask, - LLVM_Type:$val, + AnyTypeOf<[I32, F32]>:$val, I32:$offset, I32:$mask_and_clamp, ShflKindAttr:$kind, @@ -1359,6 +1359,10 @@ def NVVM_ShflOp : a mask for logically splitting warps into sub-segments and an upper bound for clamping the source lane index. + Optionally, `return_value_and_is_valid` can be specified to return a + two-element struct with the result and a predicate indicating if the + computed source lane index is valid. + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync) }]; string llvmBuilder = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a23245f92cee7..b5b07929bab6a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -870,10 +870,10 @@ LogicalResult ShflOp::verify() { auto type = llvm::dyn_cast(getType()); if ((*this)->getAttrOfType("return_value_and_is_valid")) { - auto elementType = (type && type.getBody().size() == 2) + auto predicateType = (type && type.getBody().size() == 2) ? llvm::dyn_cast(type.getBody()[1]) : nullptr; - if (!elementType || elementType.getWidth() != 1) + if (!predicateType || predicateType.getWidth() != 1) return emitOpError("expected return type to be a two-element struct with " "i1 as the second element"); } else { diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 90208aa55bd55..aaf9f8024bfbe 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -684,13 +684,6 @@ func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 // ----- -func.func @nvvm_invalid_shfl_pred_4(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}} - %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)> -} - -// ----- - func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir new file mode 100644 index 0000000000000..d2fe21c841a76 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)> +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir new file mode 100644 index 0000000000000..1b93f20c15b99 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () { + // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}} + %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape} : vector<2 x i32> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8cb7b068498fd..09b8f593154b5 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -621,11 +621,3 @@ func.func @invalid_range_equal_bounds() { %0 = nvvm.read.ptx.sreg.warpsize range : i32 return } - -// ----- - -llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () { - // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}} - %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape} : vector<2 x i32> - llvm.return -} From 29aa80ec9c3cc5231393d16ffb4d2aaa5c588b69 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 31 Oct 2025 06:06:38 +0000 Subject: [PATCH 4/8] fix formatting --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index b5b07929bab6a..33f9a256a78d0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -871,8 +871,8 @@ LogicalResult ShflOp::verify() { if ((*this)->getAttrOfType("return_value_and_is_valid")) { auto predicateType = (type && type.getBody().size() == 2) - ? llvm::dyn_cast(type.getBody()[1]) - : nullptr; + ? llvm::dyn_cast(type.getBody()[1]) + : nullptr; if (!predicateType || predicateType.getWidth() != 1) return emitOpError("expected return type to be a two-element struct with " "i1 as the second element"); From 867f33c14e877dd89d0e5f2e04be6b996482fb88 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Mon, 3 Nov 2025 12:25:56 +0000 Subject: [PATCH 5/8] update and refactor shfl.sync Op verifier --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 5 ++- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 38 +++++++++++++------ mlir/test/Dialect/LLVMIR/invalid.mlir | 6 +-- .../Target/LLVMIR/nvvm/shfl-sync-invalid.mlir | 16 +++++++- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 1e915e3027d58..4d6b0acffe862 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1359,8 +1359,9 @@ def NVVM_ShflOp : a mask for logically splitting warps into sub-segments and an upper bound for clamping the source lane index. - Optionally, `return_value_and_is_valid` can be specified to return a - two-element struct with the result and a predicate indicating if the + The `return_value_and_is_valid` unit attribute can be specified to indicate + that the return value is a two-element struct, where the first element is + the result value and the second element is a predicate indicating if the computed source lane index is valid. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 33f9a256a78d0..1367f63d4b7a3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -867,19 +867,33 @@ LogicalResult MmaOp::verify() { } LogicalResult ShflOp::verify() { - auto type = llvm::dyn_cast(getType()); - - if ((*this)->getAttrOfType("return_value_and_is_valid")) { - auto predicateType = (type && type.getBody().size() == 2) - ? llvm::dyn_cast(type.getBody()[1]) - : nullptr; - if (!predicateType || predicateType.getWidth() != 1) - return emitOpError("expected return type to be a two-element struct with " - "i1 as the second element"); + auto returnStructType = llvm::dyn_cast(getType()); + + if (returnStructType && !getReturnValueAndIsValid()) + return emitOpError("\"return_value_and_is_valid\" attribute must be " + "specified when the return type is a struct type"); + + if (getReturnValueAndIsValid()) { + if (!returnStructType || returnStructType.getBody().size() != 2) + return emitOpError("expected return type to be a two-element struct"); + + llvm::ArrayRef returnStruct = returnStructType.getBody(); + + auto resultType = returnStruct[0]; + if (resultType != getVal().getType()) + return emitOpError( + "expected first element in the returned struct to be of type ") + << getVal().getType() << " but got " << resultType << " instead."; + + auto predicateType = returnStruct[1]; + if (!predicateType.isInteger(1)) + return emitOpError("expected second element in the returned struct to be " + "of type 'i1' but got ") + << predicateType << " instead."; } else { - if (type) - return emitOpError("\"return_value_and_is_valid\" attribute must be " - "specified when returning the predicate"); + if (getType() != getVal().getType()) + return emitOpError("expected return type to be of type ") + << getVal().getType() << " but got " << getType() << " instead."; } return success(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index aaf9f8024bfbe..ba74d4d9585a3 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() { // ----- func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected return type to be a two-element struct}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32 } // ----- func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected return type to be a two-element struct}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)> } // ----- func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead.}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)> } diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir index d2fe21c841a76..cd65eab977216 100644 --- a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir @@ -3,6 +3,20 @@ // ----- func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}} + // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}} %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)> } + +// ----- + +func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead.}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32 +} + +// ----- + +func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead.}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)> +} From 9241b6a72aa98b4560226d57c2d1772d8a1afd38 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Mon, 3 Nov 2025 13:12:57 +0000 Subject: [PATCH 6/8] refactor shfl.sync verifier --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 33 +++++++++++++--------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 1367f63d4b7a3..beda3d08d1ba6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -869,31 +869,38 @@ LogicalResult MmaOp::verify() { LogicalResult ShflOp::verify() { auto returnStructType = llvm::dyn_cast(getType()); - if (returnStructType && !getReturnValueAndIsValid()) - return emitOpError("\"return_value_and_is_valid\" attribute must be " - "specified when the return type is a struct type"); + auto mismatchedType = [&](Twine desc, Type expectedType, + Type actualType) -> LogicalResult { + return emitOpError("expected " + desc + " to be of type ") + << expectedType << " but got " << actualType << " instead."; + }; + + if (returnStructType) { + if (!getReturnValueAndIsValid()) + return emitOpError("\"return_value_and_is_valid\" attribute must be " + "specified when the return type is a struct type"); - if (getReturnValueAndIsValid()) { - if (!returnStructType || returnStructType.getBody().size() != 2) + if (returnStructType.getBody().size() != 2) return emitOpError("expected return type to be a two-element struct"); llvm::ArrayRef returnStruct = returnStructType.getBody(); auto resultType = returnStruct[0]; if (resultType != getVal().getType()) - return emitOpError( - "expected first element in the returned struct to be of type ") - << getVal().getType() << " but got " << resultType << " instead."; + return mismatchedType("first element in the returned struct", + getVal().getType(), resultType); auto predicateType = returnStruct[1]; if (!predicateType.isInteger(1)) - return emitOpError("expected second element in the returned struct to be " - "of type 'i1' but got ") - << predicateType << " instead."; + return mismatchedType("second element in the returned struct", + mlir::IntegerType::get(getContext(), 1), + predicateType); } else { + if (getReturnValueAndIsValid()) + return emitOpError("expected return type to be a two-element struct"); + if (getType() != getVal().getType()) - return emitOpError("expected return type to be of type ") - << getVal().getType() << " but got " << getType() << " instead."; + return mismatchedType("return type", getVal().getType(), getType()); } return success(); } From 4c722d4476365052a16ae106f77798c6bc71a3bf Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 4 Nov 2025 05:08:24 +0000 Subject: [PATCH 7/8] address comments --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index beda3d08d1ba6..afc2161a9acdc 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -869,8 +869,8 @@ LogicalResult MmaOp::verify() { LogicalResult ShflOp::verify() { auto returnStructType = llvm::dyn_cast(getType()); - auto mismatchedType = [&](Twine desc, Type expectedType, - Type actualType) -> LogicalResult { + auto verifyTypeError = [&](Twine desc, Type expectedType, + Type actualType) -> LogicalResult { return emitOpError("expected " + desc + " to be of type ") << expectedType << " but got " << actualType << " instead."; }; @@ -884,23 +884,22 @@ LogicalResult ShflOp::verify() { return emitOpError("expected return type to be a two-element struct"); llvm::ArrayRef returnStruct = returnStructType.getBody(); - auto resultType = returnStruct[0]; if (resultType != getVal().getType()) - return mismatchedType("first element in the returned struct", - getVal().getType(), resultType); + return verifyTypeError("first element in the returned struct", + getVal().getType(), resultType); auto predicateType = returnStruct[1]; if (!predicateType.isInteger(1)) - return mismatchedType("second element in the returned struct", - mlir::IntegerType::get(getContext(), 1), - predicateType); + return verifyTypeError("second element in the returned struct", + mlir::IntegerType::get(getContext(), 1), + predicateType); } else { if (getReturnValueAndIsValid()) return emitOpError("expected return type to be a two-element struct"); if (getType() != getVal().getType()) - return mismatchedType("return type", getVal().getType(), getType()); + return verifyTypeError("return type", getVal().getType(), getType()); } return success(); } From b686b443513ca6b3f971480d4b7645b3909d4127 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 4 Nov 2025 07:13:44 +0000 Subject: [PATCH 8/8] update punctuation in error message --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index afc2161a9acdc..cfcddd3f71a0d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -872,7 +872,7 @@ LogicalResult ShflOp::verify() { auto verifyTypeError = [&](Twine desc, Type expectedType, Type actualType) -> LogicalResult { return emitOpError("expected " + desc + " to be of type ") - << expectedType << " but got " << actualType << " instead."; + << expectedType << " but got " << actualType << " instead"; }; if (returnStructType) { diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index ba74d4d9585a3..49b6342aea538 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -678,7 +678,7 @@ func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 // ----- func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead.}} + // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)> } diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir index cd65eab977216..f2ccfe71a3f23 100644 --- a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir @@ -10,13 +10,13 @@ func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : // ----- func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead.}} + // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead}} %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32 } // ----- func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead.}} + // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead}} %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)> }