Skip to content

Commit e48a6a8

Browse files
committed
[CoreDSL] fix register initializers > 64 bit
1 parent 412905a commit e48a6a8

File tree

6 files changed

+94
-59
lines changed

6 files changed

+94
-59
lines changed

docs/CoreDSLDialect.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ _Declares a register (file)._
646646
Syntax:
647647

648648
```
649-
operation ::= `coredsl.register` $accessMode oilist(`const` $isConst | `volatile` $isVolatile) $sym_name (`[` $numElements^ `]`)? `` custom<Initializer>($initializer) `:` $regType attr-dict
649+
operation ::= `coredsl.register` $accessMode oilist(`const` $isConst | `volatile` $isVolatile) $sym_name (`[` $numElements^ `]`)? custom<Initializer>($initializer, $regType) attr-dict
650650
```
651651

652652
This operation declares a [CoreDSL register](https://github.com/Minres/CoreDSL/wiki/Structure-and-concepts#registers)
@@ -681,7 +681,7 @@ Interfaces: `GetSetOpInterface`, `Symbol`
681681
<tr><td><code>isConst</code></td><td>::mlir::UnitAttr</td><td>unit attribute</td></tr>
682682
<tr><td><code>isVolatile</code></td><td>::mlir::UnitAttr</td><td>unit attribute</td></tr>
683683
<tr><td><code>numElements</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr>
684-
<tr><td><code>initializer</code></td><td>::mlir::ArrayAttr</td><td>Index array attribute</td></tr>
684+
<tr><td><code>initializer</code></td><td>::mlir::ElementsAttr</td><td>constant vector/tensor attribute</td></tr>
685685
<tr><td><code>regType</code></td><td>::mlir::TypeAttr</td><td>any type attribute</td></tr>
686686
<tr><td><code>accessMode</code></td><td>::mlir::coredsl::RegisterAccessModeAttr</td><td>coredsl.register access mode</td></tr>
687687
</table>

include/shortnail/Dialect/CoreDSL/CoreDSLOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,11 @@ def RegisterAccessMode : I64EnumAttr<
182182
"coredsl.register access mode",
183183
[DefaultRegisterFile, PCRegister, FPRegister, LocalRegister]>;
184184

185-
def IndexArrayAttr
186-
: TypedArrayAttrBase<IndexAttr, "Index array attribute">;
187-
188185
def CoreDSL_RegisterOp :
189186
GetSettableOp<"register", /*prefixAssemblyFormat=*/"$accessMode",
190-
/*suffixAssemblyFormat=*/"(`[` $numElements^ `]`)? `` custom<Initializer>($initializer) `:` $regType attr-dict",
187+
/*suffixAssemblyFormat=*/"(`[` $numElements^ `]`)? custom<Initializer>($initializer, $regType) attr-dict",
191188
/*moreArgs=*/(ins OptionalAttr<IndexAttr>:$numElements,
192-
OptionalAttr<IndexArrayAttr>:$initializer, TypeAttr:$regType,
189+
OptionalAttr<ElementsAttr>:$initializer, TypeAttr:$regType,
193190
RegisterAccessMode:$accessMode)> {
194191
let summary = "Declares a register (file).";
195192
let description = [{

lib/Conversion/CoreDSLToPy/CoreDSLToPy.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -883,17 +883,14 @@ def init_cust_regs():
883883
os << "state[\"cust_regs\"][" << reg.getSymNameAttr() << "] = [0] * "
884884
<< reg.getSize() << "\n";
885885
// Handle init values
886-
if (ArrayAttr initializer = reg.getInitializerAttr()) {
887-
bool regIsSigned = cast<IntegerType>(reg.getRegType()).isSigned();
888-
886+
if (auto initializer = dyn_cast_or_null<DenseIntElementsAttr>(
887+
reg.getInitializerAttr())) {
888+
bool isSigned = cast<IntegerType>(reg.getRegType()).isSigned();
889889
unsigned addr = 0;
890-
for (auto iv : initializer.getAsValueRange<IntegerAttr>()) {
890+
for (auto iv : initializer.getValues<APInt>()) {
891891
os << "state[\"cust_regs\"][" << reg.getSymNameAttr() << "]["
892892
<< addr << "] = int(";
893-
if (regIsSigned)
894-
os << iv.getSExtValue();
895-
else
896-
os << iv.getZExtValue();
893+
iv.print(os, isSigned);
897894
os << ")\n";
898895
++addr;
899896
}
@@ -905,14 +902,11 @@ def init_cust_regs():
905902
for (auto rom : roms) {
906903
os << ROM_PREFIX << rom.getSymName() << " = [";
907904

908-
ArrayAttr initializer = rom.getInitializerAttr();
909-
bool regIsSigned = cast<IntegerType>(rom.getRegType()).isSigned();
910-
for (auto iv : initializer.getAsValueRange<IntegerAttr>()) {
905+
bool isSigned = cast<IntegerType>(rom.getRegType()).isSigned();
906+
auto initializer = cast<DenseIntElementsAttr>(rom.getInitializerAttr());
907+
for (auto iv : initializer.getValues<APInt>()) {
911908
os << "int(";
912-
if (regIsSigned)
913-
os << iv.getSExtValue();
914-
else
915-
os << iv.getZExtValue();
909+
iv.print(os, isSigned);
916910
os << "), ";
917911
}
918912
os << "]\n";

lib/Dialect/CoreDSL/CoreDSLOps.cpp

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
#include "mlir/IR/BuiltinOps.h"
1818
#include "mlir/IR/OpImplementation.h"
1919

20-
#include "llvm/ADT/APSInt.h"
21-
2220
using namespace circt::hwarith;
2321

2422
namespace mlir {
@@ -443,23 +441,17 @@ LogicalResult RegisterOp::verify() {
443441
}
444442

445443
// Initializer checks
446-
if (ArrayAttr initializer = getInitializerAttr()) {
447-
auto initValues = initializer.getAsValueRange<IntegerAttr>();
448-
unsigned initSize = std::distance(initValues.begin(), initValues.end());
449-
if (initSize != getSize())
444+
if (ElementsAttr initializer = getInitializerAttr()) {
445+
if (initializer.getNumElements() != static_cast<int64_t>(getSize()))
450446
return emitError(
451447
"number of elements in initializer does not match register size");
452448

453-
// check that the values do not exceed the type size
454-
unsigned regTypeWidth = getRegType().getIntOrFloatBitWidth();
455-
for (auto iv : initValues) {
456-
unsigned ivWidth =
457-
getAPIntBitWidth(iv, cast<IntegerType>(getRegType()).isSigned());
458-
if (ivWidth > regTypeWidth) {
459-
return emitError("initial value width exceeds register width: ")
460-
<< iv.getSExtValue() << " (" << ivWidth << " bits)";
461-
}
462-
}
449+
if (!isa<DenseIntElementsAttr>(initializer))
450+
return emitOpError("initializer must be a DenseIntElementsAttr");
451+
452+
auto init = cast<DenseIntElementsAttr>(initializer);
453+
if (init.getElementType() != getRegType())
454+
return emitError("initial value type must match the register type");
463455
} else {
464456
if (getIsConst())
465457
return emitError("Const registers must be initialized");
@@ -485,48 +477,84 @@ LogicalResult RegisterOp::verify() {
485477
return success();
486478
}
487479

488-
static ParseResult parseInitializer(OpAsmParser &parser, ArrayAttr &attr) {
489-
auto &builder = parser.getBuilder();
490-
if (failed(parser.parseOptionalEqual()))
491-
// No initializer!
492-
return success();
480+
static ParseResult parseInitializer(OpAsmParser &parser, ElementsAttr &attr,
481+
TypeAttr &regTypeAttr) {
482+
Type regType;
483+
if (failed(parser.parseOptionalEqual())) {
484+
// No initializer, but we still need a type!
485+
auto res = parser.parseColonType(regType);
486+
regTypeAttr = TypeAttr::get(regType);
487+
return res;
488+
}
493489

494-
SmallVector<int64_t> values;
490+
auto valuesLoc = parser.getCurrentLocation();
491+
SmallVector<APInt> values;
495492
auto parseInt = [&]() -> ParseResult {
496-
int64_t v;
493+
APInt v;
497494
auto res = parser.parseOptionalInteger(v);
498495
if (!res.has_value() || failed(*res))
499496
return failure();
500497
values.push_back(v);
501498
return success();
502499
};
503500

504-
if (succeeded(parseInt()) || succeeded(parser.parseCommaSeparatedList(
505-
AsmParser::Delimiter::Square, parseInt))) {
506-
attr = builder.getIndexArrayAttr(values);
501+
if ((succeeded(parseInt()) || succeeded(parser.parseCommaSeparatedList(
502+
AsmParser::Delimiter::Square, parseInt))) &&
503+
succeeded(parser.parseColonType(regType))) {
504+
unsigned targetWidth = regType.getIntOrFloatBitWidth();
505+
bool isSigned = cast<IntegerType>(regType).isSigned();
506+
507+
// Validate ranges and resize APInts to match register type bitwidth
508+
for (auto &v : values) {
509+
unsigned reqWidth = getAPIntBitWidth(v, isSigned);
510+
if (reqWidth > targetWidth) {
511+
auto diag =
512+
parser.emitError(valuesLoc, "initial value width exceeds register "
513+
"width: ");
514+
SmallString<32> valStr;
515+
llvm::raw_svector_ostream valOs(valStr);
516+
v.print(valOs, /*isSigned=*/true);
517+
diag << valStr << " (" << reqWidth << " bits)";
518+
return diag;
519+
}
520+
if (v.getBitWidth() < targetWidth)
521+
v = isSigned ? v.sext(targetWidth) : v.zext(targetWidth);
522+
else if (v.getBitWidth() > targetWidth)
523+
v = v.trunc(targetWidth);
524+
}
525+
526+
auto shapedType =
527+
RankedTensorType::get({static_cast<int64_t>(values.size())}, regType);
528+
attr = DenseIntElementsAttr::get(shapedType, values);
529+
regTypeAttr = TypeAttr::get(regType);
507530
return success();
508531
}
509532

510533
return failure();
511534
}
512535

513-
static void printInitializer(OpAsmPrinter &p, Operation *op, ArrayAttr attr) {
514-
if (!attr)
536+
static void printInitializer(OpAsmPrinter &p, Operation *op, ElementsAttr attr,
537+
TypeAttr regTypeAttr) {
538+
if (!attr) {
539+
p << " : " << regTypeAttr;
515540
return;
541+
}
542+
543+
bool isSigned = !cast<IntegerType>(regTypeAttr.getValue()).isUnsigned();
544+
auto printValue = [&](const APInt &v) { v.print(p.getStream(), isSigned); };
516545

517546
p << " = ";
518-
auto values = attr.getValue();
547+
auto values = attr.getValues<APInt>();
519548

520549
if (values.size() == 1) {
521-
p << cast<IntegerAttr>(values.front()).getAPSInt().getSExtValue();
550+
printValue(*values.begin());
551+
p << " : " << regTypeAttr;
522552
return;
523553
}
524554

525555
p << "[";
526-
llvm::interleaveComma(values, p, [&](Attribute v) {
527-
p << cast<IntegerAttr>(v).getAPSInt().getSExtValue();
528-
});
529-
p << "]";
556+
llvm::interleaveComma(values, p, printValue);
557+
p << "] : " << regTypeAttr;
530558
}
531559

532560
//===----------------------------------------------------------------------===//
@@ -787,11 +815,12 @@ LogicalResult GetOp::canonicalize(GetOp op, PatternRewriter &rewriter) {
787815
auto regOp = cast<RegisterOp>(resolvedSym);
788816
auto initOpt = regOp.getInitializer();
789817
assert(initOpt.has_value());
790-
auto constVal = cast<IntegerAttr>(initOpt->getValue()[initIdx]);
818+
assert(isa<DenseIntElementsAttr>(initOpt.value()));
819+
auto constVal =
820+
cast<DenseIntElementsAttr>(initOpt.value()).getValues<APInt>()[initIdx];
791821
auto resType = op->getResult(0).getType();
792822
rewriter.replaceOpWithNewOp<ConstantOp>(
793-
op, resType,
794-
rewriter.getIntegerAttr(resType, constVal.getValue().getSExtValue()));
823+
op, resType, rewriter.getIntegerAttr(resType, constVal));
795824

796825
return success();
797826
}

test/CoreDSL/illegal_mem_expr.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ coredsl.isax "" {
8484
coredsl.register core_x @ACC6 = 256 : ui8
8585
}
8686

87+
// -----
88+
8789
coredsl.isax "" {
8890
// Legal init value expressions:
8991
// expected-error @+1 {{initial value width exceeds register width}}

test/CoreDSL/registers.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ coredsl.isax "" {
5050

5151
// -----
5252

53+
// Test for arbitrary-width (>64 bit) register initialization values
54+
// CHECK: coredsl.register local @BIG_UNSIGNED = 28446744073709551616 : ui65
55+
// CHECK: coredsl.register local @BIG_SIGNED = -28446744073709551617 : si66
56+
// CHECK: coredsl.register local const @BIG_ARRAY[2] = [18446744073709551616, 36893488147419103231] : ui65
57+
coredsl.isax "" {
58+
coredsl.register core_pc @PC : ui32
59+
coredsl.register local @BIG_UNSIGNED = 28446744073709551616 : ui65
60+
coredsl.register local @BIG_SIGNED = -28446744073709551617 : si66
61+
coredsl.register local const @BIG_ARRAY[2] = [18446744073709551616, 36893488147419103231] : ui65
62+
}
63+
64+
// -----
65+
5366
coredsl.isax "" {
5467
// expected-error @+1 {{Const registers must be initialized}}
5568
coredsl.register local const @C1 : ui32

0 commit comments

Comments
 (0)