Skip to content

Commit e712871

Browse files
authored
[MLIR][NVVM] Fix assertion failure for insufficient parsing validation of nvvm dialect PureSpecialRangeableRegisterOp (#163434)
The nvvm dialect instruction PureSpecialRangeableRegisterOp will trigger an assertion failure in LLVM's constant range class when the lower and upper range bounds are equal, but not equal to the integer minimum or max (as required by constant ranges). This requirement is at [line 56 of ConstantRange.cpp](https://llvm.org/doxygen/ConstantRange_8cpp_source.html#l00056): `assert((Lower != Upper || (Lower.isMaxValue() || Lower.isMinValue())) && "Lower == Upper, but they aren't min or max value!");` However, you can write an NVVM dialect operation such as: `%0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32` which triggers this assertion. This change adds a fix to ensure that this requirement is also enforced by NVVM.
1 parent 24a4ad8 commit e712871

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
263263
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
264264
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
265265
let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
266+
let hasVerifier = 1;
266267

267268
// Backwards-compatibility builder for an unspecified range.
268269
let builders = [
@@ -279,6 +280,11 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
279280
SetIntRangeFn setResultRanges) {
280281
nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
281282
}
283+
284+
// Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
285+
::llvm::LogicalResult $cppClass::verify() {
286+
return verifyConstantRangeAttr(getOperation(), getRange());
287+
}
282288
}];
283289

284290
}

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,6 +2334,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
23342334
}
23352335
}
23362336

2337+
/// Verify the range attribute satisfies LLVM ConstantRange constructor
2338+
/// requirements for NVVM SpecialRangeableRegisterOp.
2339+
static LogicalResult
2340+
verifyConstantRangeAttr(Operation *op,
2341+
std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
2342+
if (!rangeAttr)
2343+
return success();
2344+
2345+
const llvm::APInt &lower = rangeAttr->getLower();
2346+
const llvm::APInt &upper = rangeAttr->getUpper();
2347+
2348+
// Check LLVM ConstantRange constructor condition
2349+
if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
2350+
unsigned bitWidth = lower.getBitWidth();
2351+
llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
2352+
llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
2353+
return op->emitOpError(
2354+
"invalid range attribute: Lower == Upper, but they aren't min (")
2355+
<< llvm::toString(minVal, 10, false) << ") or max ("
2356+
<< llvm::toString(maxVal, 10, false)
2357+
<< ") value! This is an invalid constant range.";
2358+
}
2359+
2360+
return success();
2361+
}
2362+
23372363
static llvm::Value *getAsPackedI32(llvm::Value *arg,
23382364
llvm::IRBuilderBase &builder) {
23392365
return builder.CreateBitCast(arg,

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,3 +567,13 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
567567
%res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1
568568
llvm.return
569569
}
570+
571+
572+
// -----
573+
574+
// Test for range validation - invalid range where lower == upper but not at extremes
575+
func.func @invalid_range_equal_bounds() {
576+
// expected-error @below {{invalid range attribute: Lower == Upper, but they aren't min (0) or max (4294967295) value! This is an invalid constant range.}}
577+
%0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
578+
return
579+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ llvm.func @nvvm_special_regs() -> i32 {
152152
%74 = nvvm.read.ptx.sreg.lanemask.ge : i32
153153
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt
154154
%75 = nvvm.read.ptx.sreg.lanemask.gt : i32
155+
// CHECK: %76 = call range(i32 0, 0) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
156+
%76 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 0> : i32
157+
// CHECK: %77 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
158+
%77 = nvvm.read.ptx.sreg.tid.x range <i32, 4294967295, 4294967295> : i32
155159
llvm.return %1 : i32
156160
}
157161

0 commit comments

Comments
 (0)