Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1341,9 +1341,9 @@ def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;

def NVVM_ShflOp :
NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
Results<(outs LLVM_Type:$res)>,
Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
LLVM_Type:$val,
AnyTypeOf<[I32, F32]>:$val,
I32:$offset,
I32:$mask_and_clamp,
ShflKindAttr:$kind,
Expand All @@ -1359,6 +1359,11 @@ def NVVM_ShflOp :
a mask for logically splitting warps into sub-segments and an upper bound
for clamping the source lane index.

The `return_value_and_is_valid` unit attribute can be specified to indicate
that the return value is a two-element struct, where the first element is
the result value and the second element is a predicate indicating if the
computed source lane index is valid.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
}];
string llvmBuilder = [{
Expand Down
46 changes: 37 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,15 +867,40 @@ LogicalResult MmaOp::verify() {
}

/// Verifies that the result type of `nvvm.shfl.sync` is consistent with the
/// type of the shuffled value and with the `return_value_and_is_valid`
/// attribute:
///   * struct result: the attribute must be present, the struct must have
///     exactly two elements, the first must have the type of `val` and the
///     second must be `i1`;
///   * non-struct result: the attribute must be absent and the result type
///     must equal the type of `val`.
/// Returns success() when all checks pass; otherwise emits an op error and
/// returns failure.
LogicalResult ShflOp::verify() {
  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());

  // Shared diagnostic of the form
  //   "expected <desc> to be of type <expected> but got <actual> instead".
  // Twine is taken by const reference, as it must never be stored or copied.
  auto verifyTypeError = [&](const Twine &desc, Type expectedType,
                             Type actualType) -> LogicalResult {
    return emitOpError("expected " + desc + " to be of type ")
           << expectedType << " but got " << actualType << " instead";
  };

  if (returnStructType) {
    // A struct result is only meaningful together with the attribute.
    if (!getReturnValueAndIsValid())
      return emitOpError("\"return_value_and_is_valid\" attribute must be "
                         "specified when the return type is a struct type");

    if (returnStructType.getBody().size() != 2)
      return emitOpError("expected return type to be a two-element struct");

    llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();

    // Element 0 carries the shuffled value, so it must match `val`.
    auto resultType = returnStruct[0];
    if (resultType != getVal().getType())
      return verifyTypeError("first element in the returned struct",
                             getVal().getType(), resultType);

    // Element 1 is the validity predicate and must be i1.
    auto predicateType = returnStruct[1];
    if (!predicateType.isInteger(1))
      return verifyTypeError("second element in the returned struct",
                             mlir::IntegerType::get(getContext(), 1),
                             predicateType);
  } else {
    // The attribute promises a struct result; a scalar result contradicts it.
    if (getReturnValueAndIsValid())
      return emitOpError("expected return type to be a two-element struct");

    if (getType() != getVal().getType())
      return verifyTypeError("return type", getVal().getType(), getType());
  }
  return success();
}

Expand Down Expand Up @@ -2451,6 +2476,9 @@ LogicalResult Tcgen05LdOp::verify() {
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
result = emitError("shape 16x32bx2 requires offset argument");

if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
result = emitError("offset argument is only supported for shape 16x32bx2");

auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() {
// -----

// `return_value_and_is_valid` is set but the result type is a plain i32
// instead of a struct, so verification must fail.
func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct}}
  %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
}

// -----

// `return_value_and_is_valid` is set but the result struct has only one
// element, so verification must fail.
func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct}}
  %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
}

// -----

// `return_value_and_is_valid` is set and the result is a two-element struct,
// but its second element is i32 rather than i1, so verification must fail.
func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead}}
  %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
}

Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// -----

// A struct result type without the `return_value_and_is_valid` attribute
// must be rejected.
func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}}
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
}

// -----

// A non-struct result type (i32) that differs from the type of the shuffled
// value (f32) must be rejected.
func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead}}
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32
}

// -----

// With `return_value_and_is_valid`, the first struct element (i32) must have
// the same type as the shuffled value (f32); a mismatch must be rejected.
func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead}}
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)>
}
9 changes: 9 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// -----

// Passing an offset operand to `nvvm.tcgen05.ld` with shape_32x32b must be
// rejected: the diagnostic states the offset is only supported for 16x32bx2.
llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
// expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
%ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
llvm.return
}
Loading