Skip to content

Commit 80473e8

Browse files
xlaukolanza
authored andcommitted
[CIR] Clean up IntAttr (llvm#1725)
- Add common CIR_ prefix - Simplify printing/parsing - Make it use IntTypeInterface
1 parent 3d954ce commit 80473e8

File tree

7 files changed

+123
-71
lines changed

7 files changed

+123
-71
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
4646
mlir::Value getConstAPSInt(mlir::Location loc, const llvm::APSInt &val) {
4747
auto ty =
4848
cir::IntType::get(getContext(), val.getBitWidth(), val.isSigned());
49-
return create<cir::ConstantOp>(loc, getAttr<cir::IntAttr>(ty, val));
49+
return create<cir::ConstantOp>(loc, cir::IntAttr::get(ty, val));
5050
}
5151

5252
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits) {
@@ -63,7 +63,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
6363

6464
mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
6565
const llvm::APInt &val) {
66-
return create<cir::ConstantOp>(loc, getAttr<cir::IntAttr>(typ, val));
66+
return create<cir::ConstantOp>(loc, cir::IntAttr::get(typ, val));
6767
}
6868

6969
cir::ConstantOp getConstant(mlir::Location loc, mlir::TypedAttr attr) {

clang/include/clang/CIR/Dialect/IR/CIRAttrs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
2525

2626
#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"
27+
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"
2728

2829
//===----------------------------------------------------------------------===//
2930
// CIR Dialect Attrs

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -406,32 +406,71 @@ def ConstRecordAttr : CIR_Attr<"ConstRecord", "const_record",
406406
// IntegerAttr
407407
//===----------------------------------------------------------------------===//
408408

409-
def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
409+
def CIR_IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
410410
let summary = "An Attribute containing a integer value";
411411
let description = [{
412412
An integer attribute is a literal attribute that represents an integral
413413
value of the specified integer type.
414414
}];
415-
let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value);
415+
416+
let parameters = (ins
417+
AttributeSelfTypeParameter<"", "cir::IntTypeInterface">:$type,
418+
APIntParameter<"">:$value
419+
);
420+
416421
let builders = [
417422
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
418423
"const llvm::APInt &":$value), [{
419-
return $_get(type.getContext(), type, value);
424+
return $_get(type.getContext(),
425+
mlir::cast<cir::IntTypeInterface>(type), value);
420426
}]>,
421427
AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "int64_t":$value), [{
422-
IntType intType = mlir::cast<IntType>(type);
423-
llvm::APInt apValue(intType.getWidth(), value, intType.isSigned());
428+
auto intType = mlir::cast<cir::IntTypeInterface>(type);
429+
mlir::APInt apValue(intType.getWidth(), value, intType.isSigned());
424430
return $_get(intType.getContext(), intType, apValue);
425431
}]>,
426432
];
433+
427434
let extraClassDeclaration = [{
428-
int64_t getSInt() const { return getValue().getSExtValue(); }
429-
uint64_t getUInt() const { return getValue().getZExtValue(); }
430-
bool isNullValue() const { return getValue() == 0; }
431-
uint64_t getBitWidth() const { return mlir::cast<IntType>(getType()).getWidth(); }
435+
int64_t getSInt() const;
436+
uint64_t getUInt() const;
437+
bool isNullValue() const;
438+
bool isSigned() const;
439+
bool isUnsigned() const;
440+
uint64_t getBitWidth() const;
441+
}];
442+
443+
let extraClassDefinition = [{
444+
int64_t $cppClass::getSInt() const {
445+
return getValue().getSExtValue();
446+
}
447+
448+
uint64_t $cppClass::getUInt() const {
449+
return getValue().getZExtValue();
450+
}
451+
452+
bool $cppClass::isNullValue() const {
453+
return getValue() == 0;
454+
}
455+
456+
bool $cppClass::isSigned() const {
457+
return mlir::cast<IntTypeInterface>(getType()).isSigned();
458+
}
459+
460+
bool $cppClass::isUnsigned() const {
461+
return mlir::cast<IntTypeInterface>(getType()).isUnsigned();
462+
}
463+
464+
uint64_t $cppClass::getBitWidth() const {
465+
return mlir::cast<IntTypeInterface>(getType()).getWidth();
466+
}
467+
}];
468+
469+
let assemblyFormat = [{
470+
`<` custom<IntLiteral>($value, ref($type)) `>`
432471
}];
472+
433473
let genVerifyDecl = 1;
434-
let hasCustomAssemblyFormat = 1;
435474
}
436475

437476
//===----------------------------------------------------------------------===//
@@ -892,7 +931,7 @@ def DynamicCastInfoAttr
892931
GlobalViewAttr:$destRtti,
893932
"mlir::FlatSymbolRefAttr":$runtimeFunc,
894933
"mlir::FlatSymbolRefAttr":$badCastFunc,
895-
IntAttr:$offsetHint);
934+
CIR_IntAttr:$offsetHint);
896935

897936
let builders = [
898937
AttrBuilderWithInferredContext<(ins "GlobalViewAttr":$srcRtti,

clang/lib/CIR/CodeGen/CIRGenExprConst.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ bool ConstantAggregateBuilder::addBits(llvm::APInt Bits, uint64_t OffsetInBits,
271271

272272
if (*FirstElemToUpdate == *LastElemToUpdate || isNull) {
273273
// All existing bits are either zero or undef.
274-
add(CGM.getBuilder().getAttr<cir::IntAttr>(charTy, BitsThisChar),
275-
OffsetInChars, /*AllowOverwrite*/ true);
274+
add(cir::IntAttr::get(charTy, BitsThisChar), OffsetInChars,
275+
/*AllowOverwrite*/ true);
276276
} else {
277277
cir::IntAttr CI = dyn_cast<cir::IntAttr>(Elems[*FirstElemToUpdate]);
278278
// In order to perform a partial update, we need the existing bitwise
@@ -286,8 +286,7 @@ bool ConstantAggregateBuilder::addBits(llvm::APInt Bits, uint64_t OffsetInBits,
286286
assert((!(CI.getValue() & UpdateMask) || AllowOverwrite) &&
287287
"unexpectedly overwriting bitfield");
288288
BitsThisChar |= (CI.getValue() & ~UpdateMask);
289-
Elems[*FirstElemToUpdate] =
290-
CGM.getBuilder().getAttr<cir::IntAttr>(charTy, BitsThisChar);
289+
Elems[*FirstElemToUpdate] = cir::IntAttr::get(charTy, BitsThisChar);
291290
}
292291
}
293292

@@ -1907,7 +1906,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
19071906
if (mlir::isa<cir::BoolType>(ty))
19081907
return builder.getCIRBoolAttr(Value.getInt().getZExtValue());
19091908
assert(mlir::isa<cir::IntType>(ty) && "expected integral type");
1910-
return CGM.getBuilder().getAttr<cir::IntAttr>(ty, Value.getInt());
1909+
return cir::IntAttr::get(ty, Value.getInt());
19111910
}
19121911
case APValue::Float: {
19131912
const llvm::APFloat &Init = Value.getFloat();
@@ -2018,8 +2017,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
20182017
llvm::APSInt real = Value.getComplexIntReal();
20192018
llvm::APSInt imag = Value.getComplexIntImag();
20202019
return builder.getAttr<cir::ComplexAttr>(
2021-
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
2022-
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
2020+
complexType, cir::IntAttr::get(complexElemTy, real),
2021+
cir::IntAttr::get(complexElemTy, imag));
20232022
}
20242023

20252024
assert(isa<cir::FPTypeInterface>(complexElemTy) &&

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
178178
// Leaves.
179179
mlir::Value VisitIntegerLiteral(const IntegerLiteral *E) {
180180
mlir::Type Ty = CGF.convertType(E->getType());
181-
return Builder.create<cir::ConstantOp>(
182-
CGF.getLoc(E->getExprLoc()),
183-
Builder.getAttr<cir::IntAttr>(Ty, E->getValue()));
181+
return Builder.getConstAPInt(CGF.getLoc(E->getExprLoc()), Ty,
182+
E->getValue());
184183
}
185184

186185
mlir::Value VisitFixedPointLiteral(const FixedPointLiteral *E) {

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 57 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,40 @@
2626
#include "mlir/Support/LLVM.h"
2727
#include "mlir/Support/LogicalResult.h"
2828

29+
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"
2930
#include "llvm/ADT/STLExtras.h"
3031
#include "llvm/ADT/TypeSwitch.h"
3132

3233
// ClangIR holds back AST references when available.
3334
#include "clang/AST/Decl.h"
3435
#include "clang/AST/DeclCXX.h"
3536
#include "clang/AST/ExprCXX.h"
37+
#include "llvm/Support/ErrorHandling.h"
38+
#include "llvm/Support/SMLoc.h"
39+
40+
//===-----------------------------------------------------------------===//
41+
// RecordMembers
42+
//===-----------------------------------------------------------------===//
3643

3744
static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
3845
static mlir::ParseResult parseRecordMembers(::mlir::AsmParser &parser,
3946
mlir::ArrayAttr &members);
4047

48+
//===-----------------------------------------------------------------===//
49+
// IntLiteral
50+
//===-----------------------------------------------------------------===//
51+
52+
static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
53+
cir::IntTypeInterface ty);
54+
55+
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
56+
llvm::APInt &value,
57+
cir::IntTypeInterface ty);
58+
59+
//===-----------------------------------------------------------------===//
60+
// FloatLiteral
61+
//===-----------------------------------------------------------------===//
62+
4163
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
4264
mlir::Type ty);
4365
static mlir::ParseResult
@@ -204,65 +226,52 @@ static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
204226
// IntAttr definitions
205227
//===----------------------------------------------------------------------===//
206228

207-
Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
208-
mlir::APInt APValue;
229+
template <typename IntT>
230+
static bool isTooLargeForType(const mlir::APInt &v, IntT expectedValue) {
231+
if constexpr (std::is_signed_v<IntT>) {
232+
return v.getSExtValue() != expectedValue;
233+
} else {
234+
return v.getZExtValue() != expectedValue;
235+
}
236+
}
209237

210-
if (!mlir::isa<IntType>(odsType))
211-
return {};
212-
auto type = mlir::cast<IntType>(odsType);
238+
template <typename IntT>
239+
static ParseResult parseIntLiteralImpl(mlir::AsmParser &p, llvm::APInt &value,
240+
cir::IntTypeInterface ty) {
241+
IntT ivalue;
242+
const bool isSigned = ty.isSigned();
243+
if (p.parseInteger(ivalue))
244+
return p.emitError(p.getCurrentLocation(), "expected integer value");
213245

214-
// Consume the '<' symbol.
215-
if (parser.parseLess())
216-
return {};
246+
value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
247+
if (isTooLargeForType(value, ivalue))
217248

218-
// Fetch arbitrary precision integer value.
219-
if (type.isSigned()) {
220-
int64_t value;
221-
if (parser.parseInteger(value))
222-
parser.emitError(parser.getCurrentLocation(), "expected integer value");
223-
APValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
224-
/*implicitTrunc=*/true);
225-
if (APValue.getSExtValue() != value)
226-
parser.emitError(parser.getCurrentLocation(),
249+
return p.emitError(p.getCurrentLocation(),
227250
"integer value too large for the given type");
228-
} else {
229-
uint64_t value;
230-
if (parser.parseInteger(value))
231-
parser.emitError(parser.getCurrentLocation(), "expected integer value");
232-
APValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
233-
/*implicitTrunc=*/true);
234-
if (APValue.getZExtValue() != value)
235-
parser.emitError(parser.getCurrentLocation(),
236-
"integer value too large for the given type");
237-
}
238251

239-
// Consume the '>' symbol.
240-
if (parser.parseGreater())
241-
return {};
252+
return success();
253+
}
242254

243-
return IntAttr::get(parser.getContext(), type, APValue);
255+
mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
256+
cir::IntTypeInterface ty) {
257+
if (ty.isSigned())
258+
return parseIntLiteralImpl<int64_t>(parser, value, ty);
259+
return parseIntLiteralImpl<uint64_t>(parser, value, ty);
244260
}
245261

246-
void IntAttr::print(AsmPrinter &printer) const {
247-
auto type = mlir::cast<IntType>(getType());
248-
printer << '<';
249-
if (type.isSigned())
250-
printer << getSInt();
262+
void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
263+
cir::IntTypeInterface ty) {
264+
if (ty.isSigned())
265+
p << value.getSExtValue();
251266
else
252-
printer << getUInt();
253-
printer << '>';
267+
p << value.getZExtValue();
254268
}
255269

256270
LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
257-
Type type, APInt value) {
258-
if (!mlir::isa<IntType>(type))
259-
return emitError() << "expected 'simple.int' type";
260-
261-
auto intType = mlir::cast<IntType>(type);
262-
if (value.getBitWidth() != intType.getWidth())
271+
cir::IntTypeInterface type, llvm::APInt value) {
272+
if (value.getBitWidth() != type.getWidth())
263273
return emitError() << "type and value bitwidth mismatch: "
264-
<< intType.getWidth() << " != " << value.getBitWidth();
265-
274+
<< type.getWidth() << " != " << value.getBitWidth();
266275
return success();
267276
}
268277

@@ -481,8 +490,8 @@ LogicalResult
481490
GlobalAnnotationValuesAttr::verify(function_ref<InFlightDiagnostic()> emitError,
482491
mlir::ArrayAttr annotations) {
483492
if (annotations.empty())
484-
return emitError()
485-
<< "GlobalAnnotationValuesAttr should at least have one annotation";
493+
return emitError() << "GlobalAnnotationValuesAttr should at least have "
494+
"one annotation";
486495

487496
for (auto &entry : annotations) {
488497
auto annoEntry = ::mlir::dyn_cast<mlir::ArrayAttr>(entry);

clang/test/CIR/IR/invalid.cir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,11 @@ module {
617617
module {
618618
// expected-error@below {{integer value too large for the given type}}
619619
cir.global external @a = #cir.int<256> : !cir.int<u, 8>
620+
}
621+
622+
// -----
623+
624+
module {
620625
// expected-error@below {{integer value too large for the given type}}
621626
cir.global external @b = #cir.int<-129> : !cir.int<s, 8>
622627
}

0 commit comments

Comments
 (0)