99#include " circt/Dialect/FIRRTL/FIRRTLReductions.h"
1010#include " circt/Dialect/FIRRTL/AnnotationDetails.h"
1111#include " circt/Dialect/FIRRTL/CHIRRTLDialect.h"
12+ #include " circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h"
1213#include " circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
1314#include " circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
1415#include " circt/Dialect/FIRRTL/FIRRTLOps.h"
2829using namespace mlir ;
2930using namespace circt ;
3031using namespace firrtl ;
32+ using llvm::MapVector;
33+ using llvm::SmallSetVector;
3134
3235// ===----------------------------------------------------------------------===//
3336// Utilities
@@ -1478,6 +1481,175 @@ struct ForceDedup : public OpReduction<CircuitOp> {
14781481 NLARemover nlaRemover;
14791482};
14801483
1484+ // / A reduction pattern that moves `MustDedup` annotations from a module onto
1485+ // / its child modules. This pattern iterates over all MustDedup annotations,
1486+ // / collects all `FInstanceLike` ops in each module of the dedup group, and
1487+ // / creates new MustDedup annotations for corresponding instances across the
1488+ // / modules. Each set of corresponding instances becomes a separate match of the
1489+ // / reduction. The reduction also removes the original MustDedup annotation on
1490+ // / the parent module.
1491+ // /
1492+ // / The pattern works by:
1493+ // / 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1494+ // / 2. For each dedup group, collecting all FInstanceLike operations in each
1495+ // / module
1496+ // / 3. Grouping corresponding instances across modules by their position/name
1497+ // / 4. Creating new MustDedup annotations for each group of corresponding
1498+ // / instances
1499+ // / 5. Removing the original MustDedup annotation from the circuit
1500+ struct MustDedupChildren : public OpReduction <CircuitOp> {
1501+ void beforeReduction (mlir::ModuleOp op) override {
1502+ symbols.clear ();
1503+ nlaRemover.clear ();
1504+ }
1505+ void afterReduction (mlir::ModuleOp op) override { nlaRemover.remove (op); }
1506+
1507+ // / Collect all MustDedup annotations and create matches for each instance
1508+ // / group.
1509+ void matches (CircuitOp circuitOp,
1510+ llvm::function_ref<void (uint64_t , uint64_t )> addMatch) override {
1511+ auto annotations = AnnotationSet (circuitOp);
1512+ uint64_t matchId = 0 ;
1513+
1514+ for (auto [annoIdx, anno] : llvm::enumerate (annotations)) {
1515+ if (!anno.isClass (mustDedupAnnoClass))
1516+ continue ;
1517+
1518+ auto modulesAttr = anno.getMember <ArrayAttr>(" modules" );
1519+ if (!modulesAttr || modulesAttr.size () < 2 )
1520+ continue ;
1521+
1522+ // Process each group of corresponding instances
1523+ processInstanceGroups (
1524+ circuitOp, modulesAttr,
1525+ [&](ArrayRef<FInstanceLike>) { addMatch (1 , matchId++); });
1526+ }
1527+ }
1528+
1529+ LogicalResult rewriteMatches (CircuitOp circuitOp,
1530+ ArrayRef<uint64_t > matches) override {
1531+ auto *context = circuitOp->getContext ();
1532+ auto annotations = AnnotationSet (circuitOp);
1533+ SmallVector<Annotation> newAnnotations;
1534+ uint64_t matchId = 0 ;
1535+
1536+ for (auto [annoIdx, anno] : llvm::enumerate (annotations)) {
1537+ if (!anno.isClass (mustDedupAnnoClass)) {
1538+ newAnnotations.push_back (anno);
1539+ continue ;
1540+ }
1541+
1542+ auto modulesAttr = anno.getMember <ArrayAttr>(" modules" );
1543+ if (!modulesAttr || modulesAttr.size () < 2 ) {
1544+ newAnnotations.push_back (anno);
1545+ continue ;
1546+ }
1547+
1548+ // Track whether any matches were selected for this annotation
1549+ bool anyMatchSelected = false ;
1550+ processInstanceGroups (
1551+ circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1552+ // Check if this instance group was selected
1553+ if (!llvm::is_contained (matches, matchId++))
1554+ return ;
1555+ anyMatchSelected = true ;
1556+
1557+ // Create the list of modules to put into this new annotation.
1558+ SmallSetVector<StringAttr, 4 > moduleTargets;
1559+ for (auto instOp : instanceGroup) {
1560+ auto target = TokenAnnoTarget ();
1561+ target.circuit = circuitOp.getName ();
1562+ target.module = instOp.getReferencedModuleName ();
1563+ moduleTargets.insert (target.toStringAttr (context));
1564+ }
1565+ if (moduleTargets.size () < 2 )
1566+ return ;
1567+
1568+ // Create a new MustDedup annotation for this list of modules.
1569+ SmallVector<NamedAttribute> newAnnoAttrs;
1570+ newAnnoAttrs.emplace_back (
1571+ StringAttr::get (context, " class" ),
1572+ StringAttr::get (context, mustDedupAnnoClass));
1573+ newAnnoAttrs.emplace_back (
1574+ StringAttr::get (context, " modules" ),
1575+ ArrayAttr::get (context,
1576+ SmallVector<Attribute>(moduleTargets.begin (),
1577+ moduleTargets.end ())));
1578+
1579+ auto newAnnoDict = DictionaryAttr::get (context, newAnnoAttrs);
1580+ newAnnotations.emplace_back (newAnnoDict);
1581+ });
1582+
1583+ // If any matches were selected, mark the original annotation for removal
1584+ // since we're replacing it with new MustDedup annotations on the child
1585+ // modules. Otherwise keep the original annotation around.
1586+ if (anyMatchSelected)
1587+ nlaRemover.markNLAsInAnnotation (anno.getAttr ());
1588+ else
1589+ newAnnotations.push_back (anno);
1590+ }
1591+
1592+ // Update circuit annotations
1593+ AnnotationSet newAnnoSet (newAnnotations, context);
1594+ newAnnoSet.applyToOperation (circuitOp);
1595+ return success ();
1596+ }
1597+
1598+ std::string getName () const override { return " must-dedup-children" ; }
1599+ bool acceptSizeIncrease () const override { return true ; }
1600+
1601+ private:
1602+ // / Helper function to process groups of corresponding instances from a
1603+ // / MustDedup annotation. Calls the provided lambda for each group of
1604+ // / corresponding instances across the modules. Only calls the lambda if there
1605+ // / are at least 2 modules.
1606+ void processInstanceGroups (
1607+ CircuitOp circuitOp, ArrayAttr modulesAttr,
1608+ llvm::function_ref<void (ArrayRef<FInstanceLike>)> callback) {
1609+ auto &symbolTable = symbols.getSymbolTable (circuitOp);
1610+
1611+ // Extract module names and get the actual modules
1612+ SmallVector<FModuleLike> modules;
1613+ for (auto moduleRef : modulesAttr.getAsRange <StringAttr>())
1614+ if (auto target = tokenizePath (moduleRef))
1615+ if (auto mod = symbolTable.lookup <FModuleLike>(target->module ))
1616+ modules.push_back (mod);
1617+
1618+ // Need at least 2 modules for deduplication
1619+ if (modules.size () < 2 )
1620+ return ;
1621+
1622+ // Collect all FInstanceLike operations from each module and group them by
1623+ // name. Instance names are a good key for matching instances across
1624+ // modules. But they may not be unique, so we need to be careful to only
1625+ // match up instances that are uniquely named within every module.
1626+ struct InstanceGroup {
1627+ SmallVector<FInstanceLike> instances;
1628+ bool nameIsUnique = true ;
1629+ };
1630+ MapVector<StringAttr, InstanceGroup> instanceGroups;
1631+ for (auto module : modules) {
1632+ SmallDenseMap<StringAttr, unsigned > nameCounts;
1633+ module .walk ([&](FInstanceLike instOp) {
1634+ auto name = instOp.getInstanceNameAttr ();
1635+ auto &group = instanceGroups[name];
1636+ if (nameCounts[name]++ > 1 )
1637+ group.nameIsUnique = false ;
1638+ group.instances .push_back (instOp);
1639+ });
1640+ }
1641+
1642+ // Call the callback for each group of instances that are uniquely named and
1643+ // consist of at least 2 instances.
1644+ for (auto &[name, group] : instanceGroups)
1645+ if (group.nameIsUnique && group.instances .size () >= 2 )
1646+ callback (group.instances );
1647+ }
1648+
1649+ ::detail::SymbolCache symbols;
1650+ NLARemover nlaRemover;
1651+ };
1652+
14811653} // namespace
14821654
14831655// ===----------------------------------------------------------------------===//
@@ -1493,25 +1665,26 @@ void firrtl::FIRRTLReducePatternDialectInterface::populateReducePatterns(
14931665 // trying to tweak operands of individual arithmetic ops.
14941666 patterns.add <ModuleSwapper, 32 >();
14951667 patterns.add <ForceDedup, 31 >();
1496- patterns.add <PassReduction, 30 >(
1668+ patterns.add <MustDedupChildren, 30 >();
1669+ patterns.add <PassReduction, 29 >(
14971670 getContext (),
14981671 firrtl::createDropName ({/* preserveMode=*/ PreserveValues::None}), false ,
14991672 true );
1500- patterns.add <PassReduction, 29 >(getContext (),
1673+ patterns.add <PassReduction, 28 >(getContext (),
15011674 firrtl::createLowerCHIRRTLPass (), true , true );
1502- patterns.add <PassReduction, 28 >(getContext (), firrtl::createInferWidths (),
1675+ patterns.add <PassReduction, 27 >(getContext (), firrtl::createInferWidths (),
15031676 true , true );
1504- patterns.add <PassReduction, 27 >(getContext (), firrtl::createInferResets (),
1677+ patterns.add <PassReduction, 26 >(getContext (), firrtl::createInferResets (),
15051678 true , true );
1506- patterns.add <FIRRTLModuleExternalizer, 26 >();
1507- patterns.add <InstanceStubber, 25 >();
1508- patterns.add <MemoryStubber, 24 >();
1509- patterns.add <EagerInliner, 23 >();
1510- patterns.add <PassReduction, 22 >(getContext (),
1679+ patterns.add <FIRRTLModuleExternalizer, 25 >();
1680+ patterns.add <InstanceStubber, 24 >();
1681+ patterns.add <MemoryStubber, 23 >();
1682+ patterns.add <EagerInliner, 22 >();
1683+ patterns.add <PassReduction, 21 >(getContext (),
15111684 firrtl::createLowerFIRRTLTypes (), true , true );
1512- patterns.add <PassReduction, 21 >(getContext (), firrtl::createExpandWhens (),
1685+ patterns.add <PassReduction, 20 >(getContext (), firrtl::createExpandWhens (),
15131686 true , true );
1514- patterns.add <PassReduction, 20 >(getContext (), firrtl::createInliner ());
1687+ patterns.add <PassReduction, 19 >(getContext (), firrtl::createInliner ());
15151688 patterns.add <PassReduction, 18 >(getContext (), firrtl::createIMConstProp ());
15161689 patterns.add <PassReduction, 17 >(
15171690 getContext (),
0 commit comments