1717#include " mlir/IR/BuiltinOps.h"
1818#include " mlir/IR/OpImplementation.h"
1919
20- #include " llvm/ADT/APSInt.h"
21-
2220using namespace circt ::hwarith;
2321
2422namespace mlir {
@@ -443,23 +441,17 @@ LogicalResult RegisterOp::verify() {
443441 }
444442
445443 // Initializer checks
446- if (ArrayAttr initializer = getInitializerAttr ()) {
447- auto initValues = initializer.getAsValueRange <IntegerAttr>();
448- unsigned initSize = std::distance (initValues.begin (), initValues.end ());
449- if (initSize != getSize ())
444+ if (ElementsAttr initializer = getInitializerAttr ()) {
445+ if (initializer.getNumElements () != static_cast <int64_t >(getSize ()))
450446 return emitError (
451447 " number of elements in initializer does not match register size" );
452448
453- // check that the values do not exceed the type size
454- unsigned regTypeWidth = getRegType ().getIntOrFloatBitWidth ();
455- for (auto iv : initValues) {
456- unsigned ivWidth =
457- getAPIntBitWidth (iv, cast<IntegerType>(getRegType ()).isSigned ());
458- if (ivWidth > regTypeWidth) {
459- return emitError (" initial value width exceeds register width: " )
460- << iv.getSExtValue () << " (" << ivWidth << " bits)" ;
461- }
462- }
449+ if (!isa<DenseIntElementsAttr>(initializer))
450+ return emitOpError (" initializer must be a DenseIntElementsAttr" );
451+
452+ auto init = cast<DenseIntElementsAttr>(initializer);
453+ if (init.getElementType () != getRegType ())
454+ return emitError (" initial value type must match the register type" );
463455 } else {
464456 if (getIsConst ())
465457 return emitError (" Const registers must be initialized" );
@@ -485,48 +477,84 @@ LogicalResult RegisterOp::verify() {
485477 return success ();
486478}
487479
488- static ParseResult parseInitializer (OpAsmParser &parser, ArrayAttr &attr) {
489- auto &builder = parser.getBuilder ();
490- if (failed (parser.parseOptionalEqual ()))
491- // No initializer!
492- return success ();
480+ static ParseResult parseInitializer (OpAsmParser &parser, ElementsAttr &attr,
481+ TypeAttr ®TypeAttr) {
482+ Type regType;
483+ if (failed (parser.parseOptionalEqual ())) {
484+ // No initializer, but we still need a type!
485+ auto res = parser.parseColonType (regType);
486+ regTypeAttr = TypeAttr::get (regType);
487+ return res;
488+ }
493489
494- SmallVector<int64_t > values;
490+ auto valuesLoc = parser.getCurrentLocation ();
491+ SmallVector<APInt> values;
495492 auto parseInt = [&]() -> ParseResult {
496- int64_t v;
493+ APInt v;
497494 auto res = parser.parseOptionalInteger (v);
498495 if (!res.has_value () || failed (*res))
499496 return failure ();
500497 values.push_back (v);
501498 return success ();
502499 };
503500
504- if (succeeded (parseInt ()) || succeeded (parser.parseCommaSeparatedList (
505- AsmParser::Delimiter::Square, parseInt))) {
506- attr = builder.getIndexArrayAttr (values);
501+ if ((succeeded (parseInt ()) || succeeded (parser.parseCommaSeparatedList (
502+ AsmParser::Delimiter::Square, parseInt))) &&
503+ succeeded (parser.parseColonType (regType))) {
504+ unsigned targetWidth = regType.getIntOrFloatBitWidth ();
505+ bool isSigned = cast<IntegerType>(regType).isSigned ();
506+
507+ // Validate ranges and resize APInts to match register type bitwidth
508+ for (auto &v : values) {
509+ unsigned reqWidth = getAPIntBitWidth (v, isSigned);
510+ if (reqWidth > targetWidth) {
511+ auto diag =
512+ parser.emitError (valuesLoc, " initial value width exceeds register "
513+ " width: " );
514+ SmallString<32 > valStr;
515+ llvm::raw_svector_ostream valOs (valStr);
516+ v.print (valOs, /* isSigned=*/ true );
517+ diag << valStr << " (" << reqWidth << " bits)" ;
518+ return diag;
519+ }
520+ if (v.getBitWidth () < targetWidth)
521+ v = isSigned ? v.sext (targetWidth) : v.zext (targetWidth);
522+ else if (v.getBitWidth () > targetWidth)
523+ v = v.trunc (targetWidth);
524+ }
525+
526+ auto shapedType =
527+ RankedTensorType::get ({static_cast <int64_t >(values.size ())}, regType);
528+ attr = DenseIntElementsAttr::get (shapedType, values);
529+ regTypeAttr = TypeAttr::get (regType);
507530 return success ();
508531 }
509532
510533 return failure ();
511534}
512535
513- static void printInitializer (OpAsmPrinter &p, Operation *op, ArrayAttr attr) {
514- if (!attr)
536+ static void printInitializer (OpAsmPrinter &p, Operation *op, ElementsAttr attr,
537+ TypeAttr regTypeAttr) {
538+ if (!attr) {
539+ p << " : " << regTypeAttr;
515540 return ;
541+ }
542+
543+ bool isSigned = !cast<IntegerType>(regTypeAttr.getValue ()).isUnsigned ();
544+ auto printValue = [&](const APInt &v) { v.print (p.getStream (), isSigned); };
516545
517546 p << " = " ;
518- auto values = attr.getValue ();
547+ auto values = attr.getValues <APInt> ();
519548
520549 if (values.size () == 1 ) {
521- p << cast<IntegerAttr>(values.front ()).getAPSInt ().getSExtValue ();
550+ printValue (*values.begin ());
551+ p << " : " << regTypeAttr;
522552 return ;
523553 }
524554
525555 p << " [" ;
526- llvm::interleaveComma (values, p, [&](Attribute v) {
527- p << cast<IntegerAttr>(v).getAPSInt ().getSExtValue ();
528- });
529- p << " ]" ;
556+ llvm::interleaveComma (values, p, printValue);
557+ p << " ] : " << regTypeAttr;
530558}
531559
532560// ===----------------------------------------------------------------------===//
@@ -787,11 +815,12 @@ LogicalResult GetOp::canonicalize(GetOp op, PatternRewriter &rewriter) {
787815 auto regOp = cast<RegisterOp>(resolvedSym);
788816 auto initOpt = regOp.getInitializer ();
789817 assert (initOpt.has_value ());
790- auto constVal = cast<IntegerAttr>(initOpt->getValue ()[initIdx]);
818+ assert (isa<DenseIntElementsAttr>(initOpt.value ()));
819+ auto constVal =
820+ cast<DenseIntElementsAttr>(initOpt.value ()).getValues <APInt>()[initIdx];
791821 auto resType = op->getResult (0 ).getType ();
792822 rewriter.replaceOpWithNewOp <ConstantOp>(
793- op, resType,
794- rewriter.getIntegerAttr (resType, constVal.getValue ().getSExtValue ()));
823+ op, resType, rewriter.getIntegerAttr (resType, constVal));
795824
796825 return success ();
797826}
0 commit comments