diff --git a/include/circt/Dialect/FIRRTL/FIRRTLUtils.h b/include/circt/Dialect/FIRRTL/FIRRTLUtils.h index 9bc3b8b42741..eb7324384392 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLUtils.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLUtils.h @@ -13,6 +13,7 @@ #ifndef CIRCT_DIALECT_FIRRTL_FIRRTLUTILS_H #define CIRCT_DIALECT_FIRRTL_FIRRTLUTILS_H +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/FIRRTL/FIRRTLOps.h" #include "mlir/IR/BuiltinOps.h" diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index ff2b292fce0e..98a78051a835 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -3165,7 +3165,6 @@ LogicalResult InstanceChoiceOp::verify() { LogicalResult InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto caseNames = getCaseNamesAttr(); - std::optional convention; for (auto moduleName : getModuleNamesAttr()) { auto moduleNameRef = cast(moduleName); if (failed(instance_like_impl::verifyReferencedModule(*this, symbolTable, @@ -3178,14 +3177,6 @@ InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { if (isa(referencedModule)) return emitOpError("intmodule must be instantiated with instance op, " "not via 'firrtl.instance_choice'"); - - if (!convention) { - convention = referencedModule.getConvention(); - continue; - } - - if (*convention != referencedModule.getConvention()) - return emitOpError("all modules must have the same convention"); } auto root = cast(caseNames[0]).getRootReference(); diff --git a/lib/Dialect/FIRRTL/FIRRTLUtils.cpp b/lib/Dialect/FIRRTL/FIRRTLUtils.cpp index afd8283a7260..91d010f8be39 100644 --- a/lib/Dialect/FIRRTL/FIRRTLUtils.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "circt/Dialect/FIRRTL/FIRRTLUtils.h" +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/HW/InnerSymbolNamespace.h" #include "circt/Dialect/Seq/SeqTypes.h" diff --git a/lib/Dialect/FIRRTL/Transforms/LowerSignatures.cpp b/lib/Dialect/FIRRTL/Transforms/LowerSignatures.cpp index bdd29b90e406..aa82ad029f03 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerSignatures.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerSignatures.cpp @@ -22,6 +22,7 @@ #include "circt/Support/InstanceGraphInterface.h" #include "mlir/IR/Threading.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "firrtl-lower-signatures" @@ -508,6 +509,7 @@ struct LowerSignaturesPass // This is the main entrypoint for the lowering pass. void LowerSignaturesPass::runOnOperation() { CIRCT_DEBUG_SCOPED_PASS_LOGGER(this); + auto &instanceGraph = getAnalysis(); // Cached attr AttrCache cache(&getContext()); @@ -516,8 +518,15 @@ void LowerSignaturesPass::runOnOperation() { auto circuit = getOperation(); for (auto mod : circuit.getOps()) { - if (lowerModuleSignature(mod, mod.getConvention(), cache, - portMap[mod.getNameAttr()]) + auto convention = mod.getConvention(); + // Instance choices select between modules with a shared port shape, so + // any module instantiated by one must use the scalarized convention. + if (llvm::any_of(instanceGraph.lookup(mod)->uses(), + [](InstanceRecord *use) { + return use->getInstance(); + })) + convention = Convention::Scalarized; + if (lowerModuleSignature(mod, convention, cache, portMap[mod.getNameAttr()]) .failed()) return signalPassFailure(); } diff --git a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp index baee4fa35426..710f0882089b 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp @@ -41,6 +41,7 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "firrtl-lower-types" @@ -1828,6 +1829,7 @@ void LowerTypesPass::runOnOperation() { CIRCT_DEBUG_SCOPED_PASS_LOGGER(this); std::vector ops; + auto &instanceGraph = getAnalysis(); // Symbol Table auto &symTbl = getAnalysis(); // Cached attr @@ -1836,7 +1838,15 @@ void LowerTypesPass::runOnOperation() { DenseMap conventionTable; auto circuit = getOperation(); for (auto module : circuit.getOps()) { - conventionTable.insert({module, module.getConvention()}); + auto convention = module.getConvention(); + // Instance choices select between modules with a shared port shape, so + // any module instantiated by one must use the scalarized convention. + if (llvm::any_of(instanceGraph.lookup(module)->uses(), + [](InstanceRecord *use) { + return use->getInstance(); + })) + convention = Convention::Scalarized; + conventionTable.insert({module, convention}); ops.push_back(module); } diff --git a/test/Dialect/FIRRTL/errors.mlir b/test/Dialect/FIRRTL/errors.mlir index 94eb9b101425..2cd22dfced69 100644 --- a/test/Dialect/FIRRTL/errors.mlir +++ b/test/Dialect/FIRRTL/errors.mlir @@ -3485,27 +3485,3 @@ firrtl.circuit "DomainCreateWrongFieldType" { // expected-error @-1 {{use of value '%period' expects different type than prior uses: '!firrtl.integer' vs '!firrtl.string'}} } } - -// ----- - -firrtl.circuit "InstanceChoiceAggregate" { - firrtl.option @Platform { - firrtl.option_case @FPGA - } - firrtl.module private @Target( - in %in: !firrtl.vector, 2>, - out %out: !firrtl.vector, 2> - ) attributes {convention = #firrtl} { } - - firrtl.module public @PublicTarget( - in %in: !firrtl.vector, 2>, - out %out: !firrtl.vector, 2> - ) attributes {convention = #firrtl}{ } - - firrtl.module @InstanceChoiceAggregate() { - // expected-error @below {{'firrtl.instance_choice' op all modules must have the same convention}} - %inst_in, %inst_out = firrtl.instance_choice inst sym @sym @Target alternatives @Platform { - @FPGA -> @PublicTarget - } (in in: !firrtl.vector, 2>, out out: !firrtl.vector, 2>) - } -} diff --git a/test/Dialect/FIRRTL/lower-signatures.mlir b/test/Dialect/FIRRTL/lower-signatures.mlir index cf21ef86d8d6..adc19c74f867 100644 --- a/test/Dialect/FIRRTL/lower-signatures.mlir +++ b/test/Dialect/FIRRTL/lower-signatures.mlir @@ -132,14 +132,13 @@ firrtl.circuit "InstanceChoice" { firrtl.module @TargetWithDomain( in %D: !firrtl.domain<@ClockDomain()>, in %b: !firrtl.bundle, y: uint<2>> domains [%D] - ) attributes {convention = #firrtl} { + ) { } - firrtl.module @FPGATargetWithDomain( - in %D: !firrtl.domain<@ClockDomain()>, - in %b: !firrtl.bundle, y: uint<2>> domains [%D] - ) attributes {convention = #firrtl} { - } + firrtl.extmodule @FPGATargetWithDomain( + in D: !firrtl.domain<@ClockDomain()>, + in b: !firrtl.bundle, y: uint<2>> domains [D] + ) firrtl.module @ASICTargetWithDomain( in %D: !firrtl.domain<@ClockDomain()>, diff --git a/test/Dialect/FIRRTL/lower-types.mlir b/test/Dialect/FIRRTL/lower-types.mlir index 2fd80cd2cdcf..036446c5a1b1 100644 --- a/test/Dialect/FIRRTL/lower-types.mlir +++ b/test/Dialect/FIRRTL/lower-types.mlir @@ -1568,6 +1568,10 @@ firrtl.circuit "InstanceChoiceTest" { firrtl.matchingconnect %1, %0 : !firrtl.uint<8> } + firrtl.extmodule private @Ext( + in in: !firrtl.bundle, b: uint<8>>, + out out: !firrtl.bundle, y: uint<8>> + ) // CHECK-LABEL: firrtl.module @InstanceChoiceTest firrtl.module @InstanceChoiceTest( @@ -1576,10 +1580,10 @@ firrtl.circuit "InstanceChoiceTest" { out %out_x: !firrtl.uint<8>, out %out_y: !firrtl.uint<8> ) { - // CHECK: %inst_in_a, %inst_in_b, %inst_out_x, %inst_out_y = firrtl.instance_choice inst @TargetModule alternatives @Platform + // CHECK: %inst_in_a, %inst_in_b, %inst_out_x, %inst_out_y = firrtl.instance_choice inst @Ext alternatives @Platform // CHECK-SAME: @FPGA -> @TargetModule // CHECK-SAME: (in in_a: !firrtl.uint<8>, in in_b: !firrtl.uint<8>, out out_x: !firrtl.uint<8>, out out_y: !firrtl.uint<8>) - %inst_in, %inst_out = firrtl.instance_choice inst @TargetModule alternatives @Platform { + %inst_in, %inst_out = firrtl.instance_choice inst @Ext alternatives @Platform { @FPGA -> @TargetModule } (in in: !firrtl.bundle, b: uint<8>>, out out: !firrtl.bundle, y: uint<8>>)