Skip to content

Commit 4c720dc

Browse files
[FIRRTL] Add reduction that moves MustDedup onto children (llvm#8969)
Add the `MustDedupChildren` reduction which takes must dedup annotations and replaces them with must dedup annotations on child modules.
1 parent fa523d8 commit 4c720dc

File tree

3 files changed

+240
-14
lines changed

3 files changed

+240
-14
lines changed

include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ struct TokenAnnoTarget {
4848
toVector(out);
4949
return std::string(out);
5050
}
51+
52+
/// Convert the annotation path to a StringAttr.
53+
StringAttr toStringAttr(MLIRContext *context) const {
54+
SmallString<32> out;
55+
toVector(out);
56+
return StringAttr::get(context, out);
57+
}
5158
};
5259

5360
// The potentially non-local resolved annotation.
@@ -141,9 +148,9 @@ struct AnnoTargetCache {
141148
AnnoTargetCache() = delete;
142149
AnnoTargetCache(const AnnoTargetCache &other) = default;
143150
AnnoTargetCache(AnnoTargetCache &&other)
144-
: targets(std::move(other.targets)){};
151+
: targets(std::move(other.targets)) {}
145152

146-
AnnoTargetCache(FModuleLike mod) { gatherTargets(mod); };
153+
AnnoTargetCache(FModuleLike mod) { gatherTargets(mod); }
147154

148155
/// Lookup the target for 'name', empty if not found.
149156
/// (check for validity using operator bool()).
@@ -359,7 +366,7 @@ struct ApplyState {
359366
IntegerAttr newID() {
360367
return IntegerAttr::get(IntegerType::get(circuit.getContext(), 64),
361368
annotationID++);
362-
};
369+
}
363370

364371
private:
365372
hw::InnerSymbolNamespaceCollection namespaces;

lib/Dialect/FIRRTL/FIRRTLReductions.cpp

Lines changed: 184 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "circt/Dialect/FIRRTL/FIRRTLReductions.h"
1010
#include "circt/Dialect/FIRRTL/AnnotationDetails.h"
1111
#include "circt/Dialect/FIRRTL/CHIRRTLDialect.h"
12+
#include "circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h"
1213
#include "circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
1314
#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
1415
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
@@ -28,6 +29,8 @@
2829
using namespace mlir;
2930
using namespace circt;
3031
using namespace firrtl;
32+
using llvm::MapVector;
33+
using llvm::SmallSetVector;
3134

3235
//===----------------------------------------------------------------------===//
3336
// Utilities
@@ -1478,6 +1481,175 @@ struct ForceDedup : public OpReduction<CircuitOp> {
14781481
NLARemover nlaRemover;
14791482
};
14801483

1484+
/// A reduction pattern that moves `MustDedup` annotations from a module onto
1485+
/// its child modules. This pattern iterates over all MustDedup annotations,
1486+
/// collects all `FInstanceLike` ops in each module of the dedup group, and
1487+
/// creates new MustDedup annotations for corresponding instances across the
1488+
/// modules. Each set of corresponding instances becomes a separate match of the
1489+
/// reduction. The reduction also removes the original MustDedup annotation on
1490+
/// the parent module.
1491+
///
1492+
/// The pattern works by:
1493+
/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1494+
/// 2. For each dedup group, collecting all FInstanceLike operations in each
1495+
/// module
1496+
/// 3. Grouping corresponding instances across modules by their position/name
1497+
/// 4. Creating new MustDedup annotations for each group of corresponding
1498+
/// instances
1499+
/// 5. Removing the original MustDedup annotation from the circuit
1500+
struct MustDedupChildren : public OpReduction<CircuitOp> {
1501+
void beforeReduction(mlir::ModuleOp op) override {
1502+
symbols.clear();
1503+
nlaRemover.clear();
1504+
}
1505+
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1506+
1507+
/// Collect all MustDedup annotations and create matches for each instance
1508+
/// group.
1509+
void matches(CircuitOp circuitOp,
1510+
llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1511+
auto annotations = AnnotationSet(circuitOp);
1512+
uint64_t matchId = 0;
1513+
1514+
for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1515+
if (!anno.isClass(mustDedupAnnoClass))
1516+
continue;
1517+
1518+
auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1519+
if (!modulesAttr || modulesAttr.size() < 2)
1520+
continue;
1521+
1522+
// Process each group of corresponding instances
1523+
processInstanceGroups(
1524+
circuitOp, modulesAttr,
1525+
[&](ArrayRef<FInstanceLike>) { addMatch(1, matchId++); });
1526+
}
1527+
}
1528+
1529+
LogicalResult rewriteMatches(CircuitOp circuitOp,
1530+
ArrayRef<uint64_t> matches) override {
1531+
auto *context = circuitOp->getContext();
1532+
auto annotations = AnnotationSet(circuitOp);
1533+
SmallVector<Annotation> newAnnotations;
1534+
uint64_t matchId = 0;
1535+
1536+
for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1537+
if (!anno.isClass(mustDedupAnnoClass)) {
1538+
newAnnotations.push_back(anno);
1539+
continue;
1540+
}
1541+
1542+
auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1543+
if (!modulesAttr || modulesAttr.size() < 2) {
1544+
newAnnotations.push_back(anno);
1545+
continue;
1546+
}
1547+
1548+
// Track whether any matches were selected for this annotation
1549+
bool anyMatchSelected = false;
1550+
processInstanceGroups(
1551+
circuitOp, modulesAttr, [&](ArrayRef<FInstanceLike> instanceGroup) {
1552+
// Check if this instance group was selected
1553+
if (!llvm::is_contained(matches, matchId++))
1554+
return;
1555+
anyMatchSelected = true;
1556+
1557+
// Create the list of modules to put into this new annotation.
1558+
SmallSetVector<StringAttr, 4> moduleTargets;
1559+
for (auto instOp : instanceGroup) {
1560+
auto target = TokenAnnoTarget();
1561+
target.circuit = circuitOp.getName();
1562+
target.module = instOp.getReferencedModuleName();
1563+
moduleTargets.insert(target.toStringAttr(context));
1564+
}
1565+
if (moduleTargets.size() < 2)
1566+
return;
1567+
1568+
// Create a new MustDedup annotation for this list of modules.
1569+
SmallVector<NamedAttribute> newAnnoAttrs;
1570+
newAnnoAttrs.emplace_back(
1571+
StringAttr::get(context, "class"),
1572+
StringAttr::get(context, mustDedupAnnoClass));
1573+
newAnnoAttrs.emplace_back(
1574+
StringAttr::get(context, "modules"),
1575+
ArrayAttr::get(context,
1576+
SmallVector<Attribute>(moduleTargets.begin(),
1577+
moduleTargets.end())));
1578+
1579+
auto newAnnoDict = DictionaryAttr::get(context, newAnnoAttrs);
1580+
newAnnotations.emplace_back(newAnnoDict);
1581+
});
1582+
1583+
// If any matches were selected, mark the original annotation for removal
1584+
// since we're replacing it with new MustDedup annotations on the child
1585+
// modules. Otherwise keep the original annotation around.
1586+
if (anyMatchSelected)
1587+
nlaRemover.markNLAsInAnnotation(anno.getAttr());
1588+
else
1589+
newAnnotations.push_back(anno);
1590+
}
1591+
1592+
// Update circuit annotations
1593+
AnnotationSet newAnnoSet(newAnnotations, context);
1594+
newAnnoSet.applyToOperation(circuitOp);
1595+
return success();
1596+
}
1597+
1598+
std::string getName() const override { return "must-dedup-children"; }
1599+
bool acceptSizeIncrease() const override { return true; }
1600+
1601+
private:
1602+
/// Helper function to process groups of corresponding instances from a
1603+
/// MustDedup annotation. Calls the provided lambda for each group of
1604+
/// corresponding instances across the modules. Only calls the lambda if there
1605+
/// are at least 2 modules.
1606+
void processInstanceGroups(
1607+
CircuitOp circuitOp, ArrayAttr modulesAttr,
1608+
llvm::function_ref<void(ArrayRef<FInstanceLike>)> callback) {
1609+
auto &symbolTable = symbols.getSymbolTable(circuitOp);
1610+
1611+
// Extract module names and get the actual modules
1612+
SmallVector<FModuleLike> modules;
1613+
for (auto moduleRef : modulesAttr.getAsRange<StringAttr>())
1614+
if (auto target = tokenizePath(moduleRef))
1615+
if (auto mod = symbolTable.lookup<FModuleLike>(target->module))
1616+
modules.push_back(mod);
1617+
1618+
// Need at least 2 modules for deduplication
1619+
if (modules.size() < 2)
1620+
return;
1621+
1622+
// Collect all FInstanceLike operations from each module and group them by
1623+
// name. Instance names are a good key for matching instances across
1624+
// modules. But they may not be unique, so we need to be careful to only
1625+
// match up instances that are uniquely named within every module.
1626+
struct InstanceGroup {
1627+
SmallVector<FInstanceLike> instances;
1628+
bool nameIsUnique = true;
1629+
};
1630+
MapVector<StringAttr, InstanceGroup> instanceGroups;
1631+
for (auto module : modules) {
1632+
SmallDenseMap<StringAttr, unsigned> nameCounts;
1633+
module.walk([&](FInstanceLike instOp) {
1634+
auto name = instOp.getInstanceNameAttr();
1635+
auto &group = instanceGroups[name];
1636+
if (nameCounts[name]++ > 1)
1637+
group.nameIsUnique = false;
1638+
group.instances.push_back(instOp);
1639+
});
1640+
}
1641+
1642+
// Call the callback for each group of instances that are uniquely named and
1643+
// consist of at least 2 instances.
1644+
for (auto &[name, group] : instanceGroups)
1645+
if (group.nameIsUnique && group.instances.size() >= 2)
1646+
callback(group.instances);
1647+
}
1648+
1649+
::detail::SymbolCache symbols;
1650+
NLARemover nlaRemover;
1651+
};
1652+
14811653
} // namespace
14821654

14831655
//===----------------------------------------------------------------------===//
@@ -1493,25 +1665,26 @@ void firrtl::FIRRTLReducePatternDialectInterface::populateReducePatterns(
14931665
// trying to tweak operands of individual arithmetic ops.
14941666
patterns.add<ModuleSwapper, 32>();
14951667
patterns.add<ForceDedup, 31>();
1496-
patterns.add<PassReduction, 30>(
1668+
patterns.add<MustDedupChildren, 30>();
1669+
patterns.add<PassReduction, 29>(
14971670
getContext(),
14981671
firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,
14991672
true);
1500-
patterns.add<PassReduction, 29>(getContext(),
1673+
patterns.add<PassReduction, 28>(getContext(),
15011674
firrtl::createLowerCHIRRTLPass(), true, true);
1502-
patterns.add<PassReduction, 28>(getContext(), firrtl::createInferWidths(),
1675+
patterns.add<PassReduction, 27>(getContext(), firrtl::createInferWidths(),
15031676
true, true);
1504-
patterns.add<PassReduction, 27>(getContext(), firrtl::createInferResets(),
1677+
patterns.add<PassReduction, 26>(getContext(), firrtl::createInferResets(),
15051678
true, true);
1506-
patterns.add<FIRRTLModuleExternalizer, 26>();
1507-
patterns.add<InstanceStubber, 25>();
1508-
patterns.add<MemoryStubber, 24>();
1509-
patterns.add<EagerInliner, 23>();
1510-
patterns.add<PassReduction, 22>(getContext(),
1679+
patterns.add<FIRRTLModuleExternalizer, 25>();
1680+
patterns.add<InstanceStubber, 24>();
1681+
patterns.add<MemoryStubber, 23>();
1682+
patterns.add<EagerInliner, 22>();
1683+
patterns.add<PassReduction, 21>(getContext(),
15111684
firrtl::createLowerFIRRTLTypes(), true, true);
1512-
patterns.add<PassReduction, 21>(getContext(), firrtl::createExpandWhens(),
1685+
patterns.add<PassReduction, 20>(getContext(), firrtl::createExpandWhens(),
15131686
true, true);
1514-
patterns.add<PassReduction, 20>(getContext(), firrtl::createInliner());
1687+
patterns.add<PassReduction, 19>(getContext(), firrtl::createInliner());
15151688
patterns.add<PassReduction, 18>(getContext(), firrtl::createIMConstProp());
15161689
patterns.add<PassReduction, 17>(
15171690
getContext(),
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// UNSUPPORTED: system-windows
2+
// See https://github.com/llvm/circt/issues/4129
3+
// RUN: circt-reduce %s --test /usr/bin/env --test-arg true --keep-best=0 --include must-dedup-children | FileCheck %s
4+
5+
// Test that MustDedup annotations are moved from parent modules to their child modules
6+
7+
// CHECK: firrtl.circuit "Top" attributes {annotations = [
8+
// CHECK-DAG: {class = "firrtl.transforms.MustDeduplicateAnnotation", modules = ["~Top|ChildA", "~Top|ChildB"]}
9+
// CHECK-DAG: {class = "firrtl.transforms.MustDeduplicateAnnotation", modules = ["~Top|ChildC", "~Top|ChildD"]}
10+
// CHECK: ]}
11+
12+
firrtl.circuit "Top" attributes {annotations = [{
13+
class = "firrtl.transforms.MustDeduplicateAnnotation",
14+
modules = ["~Top|ParentA", "~Top|ParentB"]
15+
}]} {
16+
firrtl.module @Top() {
17+
firrtl.instance parentA @ParentA()
18+
firrtl.instance parentB @ParentB()
19+
}
20+
21+
firrtl.module private @ParentA() {
22+
firrtl.instance child1 @ChildA()
23+
firrtl.instance child2 @ChildC()
24+
}
25+
26+
firrtl.module private @ParentB() {
27+
firrtl.instance child1 @ChildB()
28+
firrtl.instance child2 @ChildD()
29+
}
30+
31+
firrtl.module private @ChildA() {
32+
%w = firrtl.wire : !firrtl.uint<8>
33+
}
34+
35+
firrtl.module private @ChildB() {
36+
%w = firrtl.wire : !firrtl.uint<8>
37+
}
38+
39+
firrtl.module private @ChildC() {
40+
%w = firrtl.wire : !firrtl.uint<8>
41+
}
42+
43+
firrtl.module private @ChildD() {
44+
%w = firrtl.wire : !firrtl.uint<8>
45+
}
46+
}

0 commit comments

Comments
 (0)