Skip to content

Commit 25a592c

Browse files
authored
[MLIR][NVVM] Update redux.sync op (#166125)
This change: - Updates the `redux.sync` NVVM Op input and output type constraints. - Adds a verifier for the Op to prevent stack dumps and hitting `llvm_unreachable` in certain invalid usage scenarios. Instead, we gracefully error out with an informative message now.
1 parent 04619db commit 25a592c

File tree

4 files changed

+91
-6
lines changed

4 files changed

+91
-6
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,9 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
476476
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
477477

478478
def NVVM_ReduxOp :
479-
NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
480-
Results<(outs LLVM_Type:$res)>,
481-
Arguments<(ins LLVM_Type:$val,
479+
NVVM_Op<"redux.sync", [NVVMRequiresSM<80>, AllTypesMatch<["res", "val"]>]>,
480+
Results<(outs AnyTypeOf<[I32, F32]>:$res)>,
481+
Arguments<(ins AnyTypeOf<[I32, F32]>:$val,
482482
ReduxKindAttr:$kind,
483483
I32:$mask_and_clamp,
484484
DefaultValuedAttr<BoolAttr, "false">:$abs,
@@ -496,6 +496,8 @@ def NVVM_ReduxOp :
496496

497497
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync)
498498
}];
499+
let hasVerifier = 1;
500+
499501
string llvmBuilder = [{
500502
auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
501503
$res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
16301630
return success();
16311631
}
16321632

1633+
LogicalResult NVVM::ReduxOp::verify() {
1634+
mlir::Type reduxType = getType();
1635+
1636+
if (!reduxType.isF32()) {
1637+
if (getAbs())
1638+
return emitOpError("abs attribute is supported only for f32 type");
1639+
if (getNan())
1640+
return emitOpError("nan attribute is supported only for f32 type");
1641+
}
1642+
1643+
NVVM::ReduxKind kind = getKind();
1644+
switch (kind) {
1645+
case NVVM::ReduxKind::ADD:
1646+
case NVVM::ReduxKind::AND:
1647+
case NVVM::ReduxKind::OR:
1648+
case NVVM::ReduxKind::XOR:
1649+
case NVVM::ReduxKind::MAX:
1650+
case NVVM::ReduxKind::MIN:
1651+
case NVVM::ReduxKind::UMAX:
1652+
case NVVM::ReduxKind::UMIN:
1653+
if (!reduxType.isInteger(32))
1654+
return emitOpError("'")
1655+
<< stringifyEnum(kind) << "' redux kind unsupported with "
1656+
<< reduxType << " type. Only supported type is 'i32'.";
1657+
break;
1658+
case NVVM::ReduxKind::FMIN:
1659+
case NVVM::ReduxKind::FMAX:
1660+
if (!reduxType.isF32())
1661+
return emitOpError("'")
1662+
<< stringifyEnum(kind) << "' redux kind unsupported with "
1663+
<< reduxType << " type. Only supported type is 'f32'.";
1664+
break;
1665+
}
1666+
1667+
return success();
1668+
}
1669+
16331670
/// Packs the given `field` into the `result`.
16341671
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
16351672
static llvm::Value *

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ using mlir::LLVM::detail::createIntrinsicCall;
3636
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
3737
NVVM::ReduxKind kind,
3838
bool hasAbs, bool hasNaN) {
39-
if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
40-
llvm_unreachable("unsupported data type for redux");
41-
4239
switch (kind) {
4340
case NVVM::ReduxKind::ADD:
4441
return llvm::Intrinsic::nvvm_redux_sync_add;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
2+
3+
// -----
4+
5+
llvm.func @redux_sync_i32_with_abs(%value: i32, %offset: i32) {
6+
// expected-error@+1 {{abs attribute is supported only for f32 type}}
7+
%res = nvvm.redux.sync add %value, %offset {abs = true}: i32 -> i32
8+
llvm.return
9+
}
10+
11+
// -----
12+
13+
llvm.func @redux_sync_i32_with_nan(%value: i32, %offset: i32) {
14+
// expected-error@+1 {{nan attribute is supported only for f32 type}}
15+
%res = nvvm.redux.sync add %value, %offset {nan = true}: i32 -> i32
16+
llvm.return
17+
}
18+
19+
// -----
20+
21+
llvm.func @redux_sync_f32_with_invalid_kind_add(%value: f32, %offset: i32) {
22+
// expected-error@+1 {{'add' redux kind unsupported with 'f32' type. Only supported type is 'i32'.}}
23+
%res = nvvm.redux.sync add %value, %offset: f32 -> f32
24+
llvm.return
25+
}
26+
27+
// -----
28+
29+
llvm.func @redux_sync_f32_with_invalid_kind_and(%value: f32, %offset: i32) {
30+
// expected-error@+1 {{'and' redux kind unsupported with 'f32' type. Only supported type is 'i32'.}}
31+
%res = nvvm.redux.sync and %value, %offset: f32 -> f32
32+
llvm.return
33+
}
34+
35+
// -----
36+
37+
llvm.func @redux_sync_i32_with_invalid_kind_fmin(%value: i32, %offset: i32) {
38+
// expected-error@+1 {{'fmin' redux kind unsupported with 'i32' type. Only supported type is 'f32'.}}
39+
%res = nvvm.redux.sync fmin %value, %offset: i32 -> i32
40+
llvm.return
41+
}
42+
43+
// -----
44+
45+
llvm.func @redux_sync_non_matching_types(%value: i32, %offset: i32) {
46+
// expected-error@+1 {{failed to verify that all of {res, val} have same type}}
47+
%res = nvvm.redux.sync add %value, %offset: i32 -> f32
48+
llvm.return
49+
}

0 commit comments

Comments
 (0)