diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 4f483859ac18d..4d6b0acffe862 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1341,9 +1341,9 @@ def ShflKindAttr : EnumAttr; def NVVM_ShflOp : NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>, - Results<(outs LLVM_Type:$res)>, + Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>, Arguments<(ins I32:$thread_mask, - LLVM_Type:$val, + AnyTypeOf<[I32, F32]>:$val, I32:$offset, I32:$mask_and_clamp, ShflKindAttr:$kind, @@ -1359,6 +1359,11 @@ def NVVM_ShflOp : a mask for logically splitting warps into sub-segments and an upper bound for clamping the source lane index. + The `return_value_and_is_valid` unit attribute can be specified to indicate + that the return value is a two-element struct, where the first element is + the result value and the second element is a predicate indicating if the + computed source lane index is valid. + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync) }]; string llvmBuilder = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4dbcc1d4b..cfcddd3f71a0d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -867,15 +867,40 @@ LogicalResult MmaOp::verify() { } LogicalResult ShflOp::verify() { - if (!(*this)->getAttrOfType("return_value_and_is_valid")) - return success(); - auto type = llvm::dyn_cast(getType()); - auto elementType = (type && type.getBody().size() == 2) - ? llvm::dyn_cast(type.getBody()[1]) - : nullptr; - if (!elementType || elementType.getWidth() != 1) - return emitError("expected return type to be a two-element struct with " - "i1 as the second element"); + auto returnStructType = llvm::dyn_cast(getType()); + + auto verifyTypeError = [&](Twine desc, Type expectedType, + Type actualType) -> LogicalResult { + return emitOpError("expected " + desc + " to be of type ") + << expectedType << " but got " << actualType << " instead"; + }; + + if (returnStructType) { + if (!getReturnValueAndIsValid()) + return emitOpError("\"return_value_and_is_valid\" attribute must be " + "specified when the return type is a struct type"); + + if (returnStructType.getBody().size() != 2) + return emitOpError("expected return type to be a two-element struct"); + + llvm::ArrayRef returnStruct = returnStructType.getBody(); + auto resultType = returnStruct[0]; + if (resultType != getVal().getType()) + return verifyTypeError("first element in the returned struct", + getVal().getType(), resultType); + + auto predicateType = returnStruct[1]; + if (!predicateType.isInteger(1)) + return verifyTypeError("second element in the returned struct", + mlir::IntegerType::get(getContext(), 1), + predicateType); + } else { + if (getReturnValueAndIsValid()) + return emitOpError("expected return type to be a two-element struct"); + + if (getType() != getVal().getType()) + return verifyTypeError("return type", getVal().getType(), getType()); + } return success(); } @@ -2451,6 +2476,9 @@ LogicalResult Tcgen05LdOp::verify() { if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset()) result = emitError("shape 16x32bx2 requires offset argument"); + if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset()) + result = emitError("offset argument is only supported for shape 16x32bx2"); + auto resTy = getRes().getType(); unsigned resLen = isa(resTy) ? llvm::cast(resTy).getNumElements() diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index aaf9f8024bfbe..49b6342aea538 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() { // ----- func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected return type to be a two-element struct}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32 } // ----- func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected return type to be a two-element struct}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)> } // ----- func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)> } diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir new file mode 100644 index 0000000000000..f2ccfe71a3f23 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)> +} + +// ----- + +func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32 +} + +// ----- + +func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)> +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir new file mode 100644 index 0000000000000..1b93f20c15b99 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () { + // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}} + %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape} : vector<2 x i32> + llvm.return +}