Skip to content

Commit a3b2878

Browse files
authored
[FIRRTL] Remove convention verification for InstanceChoice (#10066)
* Remove verification that requires referred modules have same conventions in InstanceChoice * Force scalarized convention in LowerTypes/LowerSignatures for referred modules of InstanceChoice
1 parent 5df8c85 commit a3b2878

File tree

8 files changed

+35
-44
lines changed

8 files changed

+35
-44
lines changed

include/circt/Dialect/FIRRTL/FIRRTLUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef CIRCT_DIALECT_FIRRTL_FIRRTLUTILS_H
1414
#define CIRCT_DIALECT_FIRRTL_FIRRTLUTILS_H
1515

16+
#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
1617
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
1718
#include "mlir/IR/BuiltinOps.h"
1819

lib/Dialect/FIRRTL/FIRRTLOps.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3165,7 +3165,6 @@ LogicalResult InstanceChoiceOp::verify() {
31653165
LogicalResult
31663166
InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
31673167
auto caseNames = getCaseNamesAttr();
3168-
std::optional<Convention> convention;
31693168
for (auto moduleName : getModuleNamesAttr()) {
31703169
auto moduleNameRef = cast<FlatSymbolRefAttr>(moduleName);
31713170
if (failed(instance_like_impl::verifyReferencedModule(*this, symbolTable,
@@ -3178,14 +3177,6 @@ InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
31783177
if (isa<FIntModuleOp>(referencedModule))
31793178
return emitOpError("intmodule must be instantiated with instance op, "
31803179
"not via 'firrtl.instance_choice'");
3181-
3182-
if (!convention) {
3183-
convention = referencedModule.getConvention();
3184-
continue;
3185-
}
3186-
3187-
if (*convention != referencedModule.getConvention())
3188-
return emitOpError("all modules must have the same convention");
31893180
}
31903181

31913182
auto root = cast<SymbolRefAttr>(caseNames[0]).getRootReference();

lib/Dialect/FIRRTL/FIRRTLUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
14+
#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
1415
#include "circt/Dialect/HW/HWOps.h"
1516
#include "circt/Dialect/HW/InnerSymbolNamespace.h"
1617
#include "circt/Dialect/Seq/SeqTypes.h"

lib/Dialect/FIRRTL/Transforms/LowerSignatures.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "circt/Support/InstanceGraphInterface.h"
2323
#include "mlir/IR/Threading.h"
2424
#include "mlir/Pass/Pass.h"
25+
#include "llvm/ADT/STLExtras.h"
2526
#include "llvm/Support/Debug.h"
2627

2728
#define DEBUG_TYPE "firrtl-lower-signatures"
@@ -508,6 +509,7 @@ struct LowerSignaturesPass
508509
// This is the main entrypoint for the lowering pass.
509510
void LowerSignaturesPass::runOnOperation() {
510511
CIRCT_DEBUG_SCOPED_PASS_LOGGER(this);
512+
auto &instanceGraph = getAnalysis<InstanceGraph>();
511513

512514
// Cached attr
513515
AttrCache cache(&getContext());
@@ -516,8 +518,15 @@ void LowerSignaturesPass::runOnOperation() {
516518
auto circuit = getOperation();
517519

518520
for (auto mod : circuit.getOps<FModuleLike>()) {
519-
if (lowerModuleSignature(mod, mod.getConvention(), cache,
520-
portMap[mod.getNameAttr()])
521+
auto convention = mod.getConvention();
522+
// Instance choices select between modules with a shared port shape, so
523+
// any module instantiated by one must use the scalarized convention.
524+
if (llvm::any_of(instanceGraph.lookup(mod)->uses(),
525+
[](InstanceRecord *use) {
526+
return use->getInstance<InstanceChoiceOp>();
527+
}))
528+
convention = Convention::Scalarized;
529+
if (lowerModuleSignature(mod, convention, cache, portMap[mod.getNameAttr()])
521530
.failed())
522531
return signalPassFailure();
523532
}

lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "mlir/Pass/Pass.h"
4242
#include "llvm/ADT/APSInt.h"
4343
#include "llvm/ADT/BitVector.h"
44+
#include "llvm/ADT/STLExtras.h"
4445
#include "llvm/Support/Debug.h"
4546

4647
#define DEBUG_TYPE "firrtl-lower-types"
@@ -1828,6 +1829,7 @@ void LowerTypesPass::runOnOperation() {
18281829
CIRCT_DEBUG_SCOPED_PASS_LOGGER(this);
18291830

18301831
std::vector<FModuleLike> ops;
1832+
auto &instanceGraph = getAnalysis<InstanceGraph>();
18311833
// Symbol Table
18321834
auto &symTbl = getAnalysis<SymbolTable>();
18331835
// Cached attr
@@ -1836,7 +1838,15 @@ void LowerTypesPass::runOnOperation() {
18361838
DenseMap<FModuleLike, Convention> conventionTable;
18371839
auto circuit = getOperation();
18381840
for (auto module : circuit.getOps<FModuleLike>()) {
1839-
conventionTable.insert({module, module.getConvention()});
1841+
auto convention = module.getConvention();
1842+
// Instance choices select between modules with a shared port shape, so
1843+
// any module instantiated by one must use the scalarized convention.
1844+
if (llvm::any_of(instanceGraph.lookup(module)->uses(),
1845+
[](InstanceRecord *use) {
1846+
return use->getInstance<InstanceChoiceOp>();
1847+
}))
1848+
convention = Convention::Scalarized;
1849+
conventionTable.insert({module, convention});
18401850
ops.push_back(module);
18411851
}
18421852

test/Dialect/FIRRTL/errors.mlir

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3485,27 +3485,3 @@ firrtl.circuit "DomainCreateWrongFieldType" {
34853485
// expected-error @-1 {{use of value '%period' expects different type than prior uses: '!firrtl.integer' vs '!firrtl.string'}}
34863486
}
34873487
}
3488-
3489-
// -----
3490-
3491-
firrtl.circuit "InstanceChoiceAggregate" {
3492-
firrtl.option @Platform {
3493-
firrtl.option_case @FPGA
3494-
}
3495-
firrtl.module private @Target(
3496-
in %in: !firrtl.vector<uint<8>, 2>,
3497-
out %out: !firrtl.vector<uint<8>, 2>
3498-
) attributes {convention = #firrtl<convention internal>} { }
3499-
3500-
firrtl.module public @PublicTarget(
3501-
in %in: !firrtl.vector<uint<8>, 2>,
3502-
out %out: !firrtl.vector<uint<8>, 2>
3503-
) attributes {convention = #firrtl<convention scalarized>}{ }
3504-
3505-
firrtl.module @InstanceChoiceAggregate() {
3506-
// expected-error @below {{'firrtl.instance_choice' op all modules must have the same convention}}
3507-
%inst_in, %inst_out = firrtl.instance_choice inst sym @sym @Target alternatives @Platform {
3508-
@FPGA -> @PublicTarget
3509-
} (in in: !firrtl.vector<uint<8>, 2>, out out: !firrtl.vector<uint<8>, 2>)
3510-
}
3511-
}

test/Dialect/FIRRTL/lower-signatures.mlir

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,13 @@ firrtl.circuit "InstanceChoice" {
132132
firrtl.module @TargetWithDomain(
133133
in %D: !firrtl.domain<@ClockDomain()>,
134134
in %b: !firrtl.bundle<x: uint<1>, y: uint<2>> domains [%D]
135-
) attributes {convention = #firrtl<convention scalarized>} {
135+
) {
136136
}
137137

138-
firrtl.module @FPGATargetWithDomain(
139-
in %D: !firrtl.domain<@ClockDomain()>,
140-
in %b: !firrtl.bundle<x: uint<1>, y: uint<2>> domains [%D]
141-
) attributes {convention = #firrtl<convention scalarized>} {
142-
}
138+
firrtl.extmodule @FPGATargetWithDomain(
139+
in D: !firrtl.domain<@ClockDomain()>,
140+
in b: !firrtl.bundle<x: uint<1>, y: uint<2>> domains [D]
141+
)
143142

144143
firrtl.module @ASICTargetWithDomain(
145144
in %D: !firrtl.domain<@ClockDomain()>,

test/Dialect/FIRRTL/lower-types.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,10 @@ firrtl.circuit "InstanceChoiceTest" {
15681568
firrtl.matchingconnect %1, %0 : !firrtl.uint<8>
15691569
}
15701570

1571+
firrtl.extmodule private @Ext(
1572+
in in: !firrtl.bundle<a: uint<8>, b: uint<8>>,
1573+
out out: !firrtl.bundle<x: uint<8>, y: uint<8>>
1574+
)
15711575

15721576
// CHECK-LABEL: firrtl.module @InstanceChoiceTest
15731577
firrtl.module @InstanceChoiceTest(
@@ -1576,10 +1580,10 @@ firrtl.circuit "InstanceChoiceTest" {
15761580
out %out_x: !firrtl.uint<8>,
15771581
out %out_y: !firrtl.uint<8>
15781582
) {
1579-
// CHECK: %inst_in_a, %inst_in_b, %inst_out_x, %inst_out_y = firrtl.instance_choice inst @TargetModule alternatives @Platform
1583+
// CHECK: %inst_in_a, %inst_in_b, %inst_out_x, %inst_out_y = firrtl.instance_choice inst @Ext alternatives @Platform
15801584
// CHECK-SAME: @FPGA -> @TargetModule
15811585
// CHECK-SAME: (in in_a: !firrtl.uint<8>, in in_b: !firrtl.uint<8>, out out_x: !firrtl.uint<8>, out out_y: !firrtl.uint<8>)
1582-
%inst_in, %inst_out = firrtl.instance_choice inst @TargetModule alternatives @Platform {
1586+
%inst_in, %inst_out = firrtl.instance_choice inst @Ext alternatives @Platform {
15831587
@FPGA -> @TargetModule
15841588
} (in in: !firrtl.bundle<a: uint<8>, b: uint<8>>, out out: !firrtl.bundle<x: uint<8>, y: uint<8>>)
15851589

0 commit comments

Comments
 (0)