Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1341,9 +1341,9 @@ def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;

def NVVM_ShflOp :
NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
Results<(outs LLVM_Type:$res)>,
Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
LLVM_Type:$val,
AnyTypeOf<[I32, F32]>:$val,
I32:$offset,
I32:$mask_and_clamp,
ShflKindAttr:$kind,
Expand All @@ -1359,6 +1359,11 @@ def NVVM_ShflOp :
a mask for logically splitting warps into sub-segments and an upper bound
for clamping the source lane index.

The `return_value_and_is_valid` unit attribute can be specified to indicate
that the return value is a two-element struct, where the first element is
the result value and the second element is a predicate indicating if the
computed source lane index is valid.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
}];
string llvmBuilder = [{
Expand Down
40 changes: 31 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,15 +867,34 @@ LogicalResult MmaOp::verify() {
}

/// Verifies that the result type of `nvvm.shfl.sync` agrees with the type of
/// its `val` operand and with the `return_value_and_is_valid` unit attribute:
///   * attribute absent: the result must not be an LLVM struct and must have
///     exactly the type of `val`;
///   * attribute present: the result must be a two-element LLVM struct whose
///     first element has the type of `val` and whose second element is `i1`.
/// Returns success() when the combination is consistent; otherwise emits an
/// op error describing the first mismatch found.
LogicalResult ShflOp::verify() {
  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  Type valType = getVal().getType();

  if (!getReturnValueAndIsValid()) {
    // A struct result is only meaningful together with the attribute, so
    // report the missing attribute rather than a generic type mismatch.
    if (returnStructType)
      return emitOpError("\"return_value_and_is_valid\" attribute must be "
                         "specified when the return type is a struct type");

    if (getType() != valType)
      return emitOpError("expected return type to be of type ")
             << valType << " but got " << getType() << " instead.";
    return success();
  }

  // From here on the attribute is set: the result must be {val-type, i1}.
  if (!returnStructType || returnStructType.getBody().size() != 2)
    return emitOpError("expected return type to be a two-element struct");

  llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();

  Type resultType = returnStruct[0];
  if (resultType != valType)
    return emitOpError(
               "expected first element in the returned struct to be of type ")
           << valType << " but got " << resultType << " instead.";

  Type predicateType = returnStruct[1];
  if (!predicateType.isInteger(1))
    return emitOpError("expected second element in the returned struct to be "
                       "of type 'i1' but got ")
           << predicateType << " instead.";

  return success();
}

Expand Down Expand Up @@ -2451,6 +2470,9 @@ LogicalResult Tcgen05LdOp::verify() {
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
result = emitError("shape 16x32bx2 requires offset argument");

if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
result = emitError("offset argument is only supported for shape 16x32bx2");

auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() {
// -----

// Negative test: `return_value_and_is_valid` is set but the result type is a
// plain i32 instead of a struct. Only one diagnostic check is kept so that it
// sits directly above the op it targets.
func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct}}
  %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
}

// -----

// Negative test: `return_value_and_is_valid` is set but the result struct has
// a single element. Only one diagnostic check is kept so that it sits
// directly above the op it targets.
func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct}}
  %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
}

// -----

// Negative test: `return_value_and_is_valid` is set and the result is a
// two-element struct, but its second element is i32 rather than i1. Only one
// diagnostic check is kept so that it sits directly above the op it targets.
func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead.}}
  %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
}

Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// -----

// Negative test: the result is a struct type but the
// `return_value_and_is_valid` attribute is not present on the op.
func.func @nvvm_invalid_shfl_pred(%thread_mask : i32, %val : f32, %offset : i32, %mask_and_clamp : i32) {
  // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}}
  %shuffled = nvvm.shfl.sync bfly %thread_mask, %val, %offset, %mask_and_clamp : f32 -> !llvm.struct<(f32, i1)>
}

// -----

// Negative test: without `return_value_and_is_valid`, an f32 value is shuffled
// but the declared result type is i32.
func.func @nvvm_invalid_shfl_invalid_return_type_1(%thread_mask : i32, %val : f32, %offset : i32, %mask_and_clamp : i32) {
  // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead.}}
  %shuffled = nvvm.shfl.sync bfly %thread_mask, %val, %offset, %mask_and_clamp : f32 -> i32
}

// -----

// Negative test: with `return_value_and_is_valid`, an f32 value is shuffled
// but the first element of the result struct is i32.
func.func @nvvm_invalid_shfl_invalid_return_type_2(%thread_mask : i32, %val : f32, %offset : i32, %mask_and_clamp : i32) {
  // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead.}}
  %shuffled = nvvm.shfl.sync bfly %thread_mask, %val, %offset, %mask_and_clamp {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)>
}
9 changes: 9 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// -----

// Negative test: an offset operand is passed to nvvm.tcgen05.ld while the
// shape attribute is shape_32x32b.
llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
  // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
  %loaded = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
  llvm.return
}