Skip to content

Commit 97947f1

Browse files
authored
[MLIR][NVVM] Update Op verifiers to prevent ungraceful exits (#165677)
Updates the following Ops to prevent ungraceful exits with a stack-dump in certain cases of incorrect usages, and instead gracefully error out with a more informative error message: - `tcgen05.ld` - `shfl.sync`
1 parent ab487b6 commit 97947f1

File tree

5 files changed

+78
-14
lines changed

5 files changed

+78
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,9 +1328,9 @@ def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
13281328

13291329
def NVVM_ShflOp :
13301330
NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
1331-
Results<(outs LLVM_Type:$res)>,
1331+
Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
13321332
Arguments<(ins I32:$thread_mask,
1333-
LLVM_Type:$val,
1333+
AnyTypeOf<[I32, F32]>:$val,
13341334
I32:$offset,
13351335
I32:$mask_and_clamp,
13361336
ShflKindAttr:$kind,
@@ -1346,6 +1346,11 @@ def NVVM_ShflOp :
13461346
a mask for logically splitting warps into sub-segments and an upper bound
13471347
for clamping the source lane index.
13481348

1349+
The `return_value_and_is_valid` unit attribute can be specified to indicate
1350+
that the return value is a two-element struct, where the first element is
1351+
the result value and the second element is a predicate indicating if the
1352+
computed source lane index is valid.
1353+
13491354
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
13501355
}];
13511356
string llvmBuilder = [{

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -920,15 +920,40 @@ LogicalResult MmaOp::verify() {
920920
}
921921

922922
LogicalResult ShflOp::verify() {
923-
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
924-
return success();
925-
auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
926-
auto elementType = (type && type.getBody().size() == 2)
927-
? llvm::dyn_cast<IntegerType>(type.getBody()[1])
928-
: nullptr;
929-
if (!elementType || elementType.getWidth() != 1)
930-
return emitError("expected return type to be a two-element struct with "
931-
"i1 as the second element");
923+
auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
924+
925+
auto verifyTypeError = [&](Twine desc, Type expectedType,
926+
Type actualType) -> LogicalResult {
927+
return emitOpError("expected " + desc + " to be of type ")
928+
<< expectedType << " but got " << actualType << " instead";
929+
};
930+
931+
if (returnStructType) {
932+
if (!getReturnValueAndIsValid())
933+
return emitOpError("\"return_value_and_is_valid\" attribute must be "
934+
"specified when the return type is a struct type");
935+
936+
if (returnStructType.getBody().size() != 2)
937+
return emitOpError("expected return type to be a two-element struct");
938+
939+
llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
940+
auto resultType = returnStruct[0];
941+
if (resultType != getVal().getType())
942+
return verifyTypeError("first element in the returned struct",
943+
getVal().getType(), resultType);
944+
945+
auto predicateType = returnStruct[1];
946+
if (!predicateType.isInteger(1))
947+
return verifyTypeError("second element in the returned struct",
948+
mlir::IntegerType::get(getContext(), 1),
949+
predicateType);
950+
} else {
951+
if (getReturnValueAndIsValid())
952+
return emitOpError("expected return type to be a two-element struct");
953+
954+
if (getType() != getVal().getType())
955+
return verifyTypeError("return type", getVal().getType(), getType());
956+
}
932957
return success();
933958
}
934959

@@ -2677,6 +2702,9 @@ LogicalResult Tcgen05LdOp::verify() {
26772702
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
26782703
result = emitError("shape 16x32bx2 requires offset argument");
26792704

2705+
if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
2706+
result = emitError("offset argument is only supported for shape 16x32bx2");
2707+
26802708
auto resTy = getRes().getType();
26812709
unsigned resLen = isa<VectorType>(resTy)
26822710
? llvm::cast<VectorType>(resTy).getNumElements()

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() {
664664
// -----
665665

666666
func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
667-
// expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
667+
// expected-error@+1 {{expected return type to be a two-element struct}}
668668
%0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
669669
}
670670

671671
// -----
672672

673673
func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
674-
// expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
674+
// expected-error@+1 {{expected return type to be a two-element struct}}
675675
%0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
676676
}
677677

678678
// -----
679679

680680
func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
681-
// expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
681+
// expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead}}
682682
%0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
683683
}
684684

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
2+
3+
// -----
4+
5+
func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
6+
// expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}}
7+
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
8+
}
9+
10+
// -----
11+
12+
func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
13+
// expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead}}
14+
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32
15+
}
16+
17+
// -----
18+
19+
func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
20+
// expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead}}
21+
%0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)>
22+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
2+
3+
// -----
4+
5+
llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
6+
// expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
7+
%ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
8+
llvm.return
9+
}

0 commit comments

Comments
 (0)