Skip to content

Commit 5e705e9

Browse files
committed
[MLIR] Make compatible with APInt ctor assertion
This fixes all the places in MLIR that hit the new assertion added in #106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag. The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future. Note that the assertion is currently still disabled by default, so this patch is mostly NFC.
1 parent f445e39 commit 5e705e9

File tree

9 files changed

+26
-13
lines changed

9 files changed

+26
-13
lines changed

mlir/include/mlir/IR/BuiltinAttributes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
701701
return $_get(type.getContext(), type, apValue);
702702
}
703703

704+
// TODO: Avoid implicit trunc?
704705
IntegerType intTy = ::llvm::cast<IntegerType>(type);
705-
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
706+
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
707+
/*implicitTrunc=*/true);
706708
return $_get(type.getContext(), type, apValue);
707709
}]>
708710
];

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,8 @@ class AsmParser {
749749
// zero for non-negated integers.
750750
result =
751751
(IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
752-
if (APInt(uintResult.getBitWidth(), result) != uintResult)
752+
if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
753+
/*implicitTrunc=*/true) != uintResult)
753754
return emitError(loc, "integer value too large");
754755
return success();
755756
}

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
4343
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
4444
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
4545
Type eTy = shapedTy.getElementType();
46-
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
46+
APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
4747
return DenseIntElementsAttr::get(shapedTy, valueInt);
4848
}
4949

mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
528528
int64_t value = 0;
529529
if (failed(parser.parseInteger(value)))
530530
return failure();
531-
values.push_back(APInt(bitWidth, value));
531+
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
532532

533533
Block *destination;
534534
SmallVector<OpAsmParser::UnresolvedOperand> operands;

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
598598
int64_t value = 0;
599599
if (failed(parser.parseInteger(value)))
600600
return failure();
601-
values.push_back(APInt(bitWidth, value));
601+
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
602602

603603
Block *destination;
604604
SmallVector<OpAsmParser::UnresolvedOperand> operands;

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
10731073
if (parser.parseInteger(value))
10741074
return failure();
10751075
shapeTmp++;
1076-
values.push_back(APInt(32, value));
1076+
values.push_back(APInt(32, value, /*isSigned=*/true));
10771077
return success();
10781078
};
10791079

mlir/lib/IR/Builders.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
234234
}
235235

236236
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
237-
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
237+
// The APInt always uses isSigned=true here because we accept the value
238+
// as int32_t.
239+
return IntegerAttr::get(getIntegerType(32),
240+
APInt(32, value, /*isSigned=*/true));
238241
}
239242

240243
IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
@@ -252,14 +255,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
252255
}
253256

254257
IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
255-
return IntegerAttr::get(getIntegerType(8), APInt(8, value));
258+
// The APInt always uses isSigned=true here because we accept the value
259+
// as int8_t.
260+
return IntegerAttr::get(getIntegerType(8),
261+
APInt(8, value, /*isSigned=*/true));
256262
}
257263

258264
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
259265
if (type.isIndex())
260266
return IntegerAttr::get(type, APInt(64, value));
261-
return IntegerAttr::get(
262-
type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
267+
// TODO: Avoid implicit trunc?
268+
return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
269+
type.isSignedInteger(),
270+
/*implicitTrunc=*/true));
263271
}
264272

265273
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,9 +1284,11 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
12841284
uint32_t word1;
12851285
uint32_t word2;
12861286
} words = {operands[2], operands[3]};
1287-
value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1287+
value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true,
1288+
/*implicitTrunc=*/true);
12881289
} else if (bitwidth <= 32) {
1289-
value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1290+
value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1291+
/*implicitTrunc=*/true);
12901292
}
12911293

12921294
auto attr = opBuilder.getIntegerAttr(intType, value);

mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
176176
IntegerType::get(&context, 16, IntegerType::Signless);
177177
auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
178178
// Check the bit extension of same value under different signedness semantics.
179-
APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
179+
APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
180180
signlessInt16Type.getSignedness());
181181
APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
182182
signedInt16Type.getSignedness());

0 commit comments

Comments
 (0)