Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1341,9 +1341,9 @@ def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;

def NVVM_ShflOp :
NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
Results<(outs LLVM_Type:$res)>,
Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
LLVM_Type:$val,
AnyTypeOf<[I32, F32]>:$val,
I32:$offset,
I32:$mask_and_clamp,
ShflKindAttr:$kind,
Expand All @@ -1359,6 +1359,10 @@ def NVVM_ShflOp :
a mask for logically splitting warps into sub-segments and an upper bound
for clamping the source lane index.

Optionally, `return_value_and_is_valid` can be specified to return a
two-element struct with the result and a predicate indicating if the
computed source lane index is valid.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
}];
string llvmBuilder = [{
Expand Down
24 changes: 16 additions & 8 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,15 +867,20 @@ LogicalResult MmaOp::verify() {
}

LogicalResult ShflOp::verify() {
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
auto elementType = (type && type.getBody().size() == 2)
? llvm::dyn_cast<IntegerType>(type.getBody()[1])
: nullptr;
if (!elementType || elementType.getWidth() != 1)
return emitError("expected return type to be a two-element struct with "
"i1 as the second element");

if ((*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) {
auto predicateType = (type && type.getBody().size() == 2)
? llvm::dyn_cast<IntegerType>(type.getBody()[1])
: nullptr;
if (!predicateType || predicateType.getWidth() != 1)
return emitOpError("expected return type to be a two-element struct with "
"i1 as the second element");
} else {
if (type)
return emitOpError("\"return_value_and_is_valid\" attribute must be "
"specified when returning the predicate");
}
return success();
}

Expand Down Expand Up @@ -2451,6 +2456,9 @@ LogicalResult Tcgen05LdOp::verify() {
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
result = emitError("shape 16x32bx2 requires offset argument");

if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
result = emitError("offset argument is only supported for shape 16x32bx2");

auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// -----

func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}}
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
}
9 changes: 9 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// -----

llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
// expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
%ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
llvm.return
}