Skip to content

Commit bd6d6e3

Browse files
committed
[FIRRTL] Add force dedup reduction
Add a reduction to the FIRRTL dialect that takes MustDedup annotations and forcefully replaces all modules in the group with the first module. This is brittle, but works fairly well in practice.
1 parent 3703b21 commit bd6d6e3

File tree

5 files changed

+347
-33
lines changed

5 files changed

+347
-33
lines changed

include/circt/Reduce/Reduction.h

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,38 @@ struct Reduction {
104104
}
105105
};
106106

107+
/// A reduction pattern for a specific operation.
108+
///
109+
/// Only matches on operations of type `OpTy`, and calls corresponding match and
110+
/// rewrite functions with the operation cast to this type, for convenience.
107111
template <typename OpTy>
108112
struct OpReduction : public Reduction {
109-
uint64_t match(Operation *op) override {
113+
void matches(Operation *op,
114+
llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
110115
if (auto concreteOp = dyn_cast<OpTy>(op))
111-
return match(concreteOp);
112-
return 0;
116+
matches(concreteOp, addMatch);
113117
}
114-
LogicalResult rewrite(Operation *op) override {
115-
return rewrite(cast<OpTy>(op));
118+
LogicalResult rewriteMatches(Operation *op,
119+
ArrayRef<uint64_t> matches) override {
120+
return rewriteMatches(cast<OpTy>(op), matches);
116121
}
117122

118123
virtual uint64_t match(OpTy op) { return 1; }
119-
virtual LogicalResult rewrite(OpTy op) = 0;
124+
virtual void matches(OpTy op,
125+
llvm::function_ref<void(uint64_t, uint64_t)> addMatch) {
126+
addMatch(match(op), 0);
127+
}
128+
virtual LogicalResult rewrite(OpTy op) { return failure(); }
129+
virtual LogicalResult rewriteMatches(OpTy op, ArrayRef<uint64_t> matches) {
130+
assert(matches.size() == 1 && matches[0] == 0);
131+
return rewrite(op);
132+
}
133+
134+
private:
135+
/// Hide the base class match/rewrite functions to prevent compiler warnings
136+
/// about the `OpTy`-specific ones hiding the base class functions.
137+
using Reduction::match;
138+
using Reduction::rewrite;
120139
};
121140

122141
/// A reduction pattern that applies an `mlir::Pass`.

lib/Dialect/FIRRTL/FIRRTLReductions.cpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "circt/Dialect/FIRRTL/FIRRTLReductions.h"
10+
#include "circt/Dialect/FIRRTL/AnnotationDetails.h"
1011
#include "circt/Dialect/FIRRTL/CHIRRTLDialect.h"
12+
#include "circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
1113
#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
1214
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
1315
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
@@ -25,6 +27,7 @@
2527

2628
using namespace mlir;
2729
using namespace circt;
30+
using namespace firrtl;
2831

2932
//===----------------------------------------------------------------------===//
3033
// Utilities
@@ -179,6 +182,8 @@ struct NLARemover {
179182
// Reduction patterns
180183
//===----------------------------------------------------------------------===//
181184

185+
namespace {
186+
182187
/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
183188
struct FIRRTLModuleExternalizer : public OpReduction<firrtl::FModuleOp> {
184189
void beforeReduction(mlir::ModuleOp op) override {
@@ -1118,6 +1123,182 @@ struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
11181123
bool isOneShot() const override { return true; }
11191124
};
11201125

1126+
/// A reduction pattern that handles MustDedup annotations by replacing all
1127+
/// module names in a dedup group with a single module name. This helps reduce
1128+
/// the IR by consolidating module references that are required to be identical.
1129+
///
1130+
/// The pattern works by:
1131+
/// 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1132+
/// 2. For each dedup group, using the first module as the canonical name
1133+
/// 3. Replacing all instance references to other modules in the group with
1134+
/// references to the canonical module
1135+
/// 4. Removing the non-canonical modules from the circuit
1136+
/// 5. Removing the processed MustDedup annotation
1137+
///
1138+
/// This reduction is particularly useful for reducing large circuits where
1139+
/// multiple modules are known to be identical but haven't been deduplicated
1140+
/// yet.
1141+
struct ForceDedup : public OpReduction<CircuitOp> {
1142+
void beforeReduction(mlir::ModuleOp op) override {
1143+
symbols.clear();
1144+
nlaRemover.clear();
1145+
}
1146+
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1147+
1148+
/// Collect all MustDedup annotations and create matches for each dedup group.
1149+
void matches(CircuitOp circuitOp,
1150+
llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
1151+
auto annotations = AnnotationSet(circuitOp);
1152+
for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1153+
if (!anno.isClass(mustDedupAnnoClass))
1154+
continue;
1155+
1156+
auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1157+
if (!modulesAttr)
1158+
continue;
1159+
1160+
// Each dedup group gets its own match with benefit proportional to group
1161+
// size.
1162+
uint64_t benefit = modulesAttr.size();
1163+
addMatch(benefit, annoIdx);
1164+
}
1165+
}
1166+
1167+
LogicalResult rewriteMatches(CircuitOp circuitOp,
1168+
ArrayRef<uint64_t> matches) override {
1169+
auto *context = circuitOp->getContext();
1170+
NLATable nlaTable(circuitOp);
1171+
hw::InnerSymbolTableCollection innerSymTables;
1172+
auto annotations = AnnotationSet(circuitOp);
1173+
SmallVector<Annotation> newAnnotations;
1174+
1175+
for (auto [annoIdx, anno] : llvm::enumerate(annotations)) {
1176+
// Check if this annotation was selected.
1177+
if (!llvm::is_contained(matches, annoIdx)) {
1178+
newAnnotations.push_back(anno);
1179+
continue;
1180+
}
1181+
auto modulesAttr = anno.getMember<ArrayAttr>("modules");
1182+
assert(anno.isClass(mustDedupAnnoClass) && modulesAttr &&
1183+
modulesAttr.size() >= 2);
1184+
1185+
// Extract module names from the dedup group.
1186+
SmallVector<StringAttr> moduleNames;
1187+
for (auto moduleRef : modulesAttr.getAsRange<StringAttr>()) {
1188+
// Parse "~CircuitName|ModuleName" format.
1189+
auto refStr = moduleRef.getValue();
1190+
auto pipePos = refStr.find('|');
1191+
if (pipePos != StringRef::npos && pipePos + 1 < refStr.size()) {
1192+
auto moduleName = refStr.substr(pipePos + 1);
1193+
moduleNames.push_back(StringAttr::get(context, moduleName));
1194+
}
1195+
}
1196+
1197+
// Simply drop the annotation if there's only one module.
1198+
if (moduleNames.size() < 2)
1199+
continue;
1200+
1201+
// Replace all instances and references to other modules with the
1202+
// first module.
1203+
replaceModuleReferences(circuitOp, moduleNames, nlaTable, innerSymTables);
1204+
nlaRemover.markNLAsInAnnotation(anno.getAttr());
1205+
}
1206+
if (newAnnotations.size() == annotations.size())
1207+
return failure();
1208+
1209+
// Update circuit annotations.
1210+
AnnotationSet newAnnoSet(newAnnotations, context);
1211+
newAnnoSet.applyToOperation(circuitOp);
1212+
return success();
1213+
}
1214+
1215+
std::string getName() const override { return "firrtl-force-dedup"; }
1216+
bool acceptSizeIncrease() const override { return true; }
1217+
1218+
private:
1219+
/// Replace all references to modules in the dedup group with the canonical
1220+
/// module name
1221+
void replaceModuleReferences(CircuitOp circuitOp,
1222+
ArrayRef<StringAttr> moduleNames,
1223+
NLATable &nlaTable,
1224+
hw::InnerSymbolTableCollection &innerSymTables) {
1225+
auto *tableOp = SymbolTable::getNearestSymbolTable(circuitOp);
1226+
auto &symbolTable = symbols.getSymbolTable(tableOp);
1227+
auto *context = circuitOp->getContext();
1228+
auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
1229+
1230+
// Collect the modules.
1231+
FModuleLike canonicalModule;
1232+
SmallVector<FModuleLike> modulesToReplace;
1233+
for (auto name : moduleNames) {
1234+
if (auto mod = symbolTable.lookup<FModuleLike>(name)) {
1235+
if (!canonicalModule)
1236+
canonicalModule = mod;
1237+
else
1238+
modulesToReplace.push_back(mod);
1239+
}
1240+
}
1241+
if (modulesToReplace.empty())
1242+
return;
1243+
1244+
// Replace all instance references.
1245+
auto canonicalName = canonicalModule.getModuleNameAttr();
1246+
auto canonicalRef = FlatSymbolRefAttr::get(canonicalName);
1247+
circuitOp.walk([&](InstanceOp instOp) {
1248+
auto moduleName = instOp.getModuleNameAttr().getAttr();
1249+
if (llvm::is_contained(moduleNames, moduleName) &&
1250+
moduleName != canonicalName) {
1251+
instOp.setModuleNameAttr(canonicalRef);
1252+
instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1253+
}
1254+
});
1255+
1256+
// Update NLAs to reference the canonical module instead of modules being
1257+
// removed using NLATable for better performance.
1258+
for (auto oldMod : modulesToReplace) {
1259+
SmallVector<hw::HierPathOp> nlaOps(
1260+
nlaTable.lookup(oldMod.getModuleNameAttr()));
1261+
for (auto nlaOp : nlaOps) {
1262+
nlaTable.erase(nlaOp);
1263+
StringAttr oldModName = oldMod.getModuleNameAttr();
1264+
StringAttr newModName = canonicalName;
1265+
SmallVector<Attribute, 4> newPath;
1266+
for (auto nameRef : nlaOp.getNamepath()) {
1267+
if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1268+
if (ref.getModule() == oldModName) {
1269+
auto oldInst = innerRefs.lookupOp<FInstanceLike>(ref);
1270+
ref = hw::InnerRefAttr::get(newModName, ref.getName());
1271+
auto newInst = innerRefs.lookupOp<FInstanceLike>(ref);
1272+
if (oldInst && newInst) {
1273+
oldModName = oldInst.getReferencedModuleNameAttr();
1274+
newModName = newInst.getReferencedModuleNameAttr();
1275+
}
1276+
}
1277+
newPath.push_back(ref);
1278+
} else if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldModName) {
1279+
newPath.push_back(FlatSymbolRefAttr::get(newModName));
1280+
} else {
1281+
newPath.push_back(nameRef);
1282+
}
1283+
}
1284+
nlaOp.setNamepathAttr(ArrayAttr::get(context, newPath));
1285+
nlaTable.addNLA(nlaOp);
1286+
}
1287+
}
1288+
1289+
// Mark NLAs in modules to be removed.
1290+
for (auto module : modulesToReplace) {
1291+
nlaRemover.markNLAsInOperation(module);
1292+
module->erase();
1293+
}
1294+
}
1295+
1296+
::detail::SymbolCache symbols;
1297+
NLARemover nlaRemover;
1298+
};
1299+
1300+
} // namespace
1301+
11211302
//===----------------------------------------------------------------------===//
11221303
// Reduction Registration
11231304
//===----------------------------------------------------------------------===//
@@ -1129,6 +1310,7 @@ void firrtl::FIRRTLReducePatternDialectInterface::populateReducePatterns(
11291310
// prioritized). For example, things that can knock out entire modules while
11301311
// being cheap should be tried first (and thus have higher benefit), before
11311312
// trying to tweak operands of individual arithmetic ops.
1313+
patterns.add<ForceDedup, 31>();
11321314
patterns.add<PassReduction, 30>(
11331315
getContext(),
11341316
firrtl::createDropName({/*preserveMode=*/PreserveValues::None}), false,

lib/Reduce/GenericReductions.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ static std::unique_ptr<Pass> createSimpleCanonicalizerPass() {
8282

8383
void circt::populateGenericReducePatterns(MLIRContext *context,
8484
ReducePatternSet &patterns) {
85-
patterns.add<UnusedSymbolPruner, 40>();
86-
patterns.add<PassReduction, 3>(context, createCSEPass());
87-
patterns.add<PassReduction, 2>(context, createSimpleCanonicalizerPass());
85+
patterns.add<PassReduction, 103>(context, createSymbolDCEPass());
86+
patterns.add<PassReduction, 102>(context, createCSEPass());
87+
patterns.add<PassReduction, 101>(context, createSimpleCanonicalizerPass());
88+
patterns.add<UnusedSymbolPruner, 100>();
8889
patterns.add<OperationPruner, 1>();
8990
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// UNSUPPORTED: system-windows
2+
// See https://github.com/llvm/circt/issues/4129
3+
// RUN: circt-reduce %s --test /usr/bin/env --test-arg true --include firrtl-force-dedup | FileCheck %s
4+
5+
// Test that the MustDedup reducer can consolidate modules in a dedup group
6+
// by replacing all module names with a single canonical module name.
7+
8+
firrtl.circuit "MustDedup" attributes {annotations = [{
9+
class = "firrtl.transforms.MustDeduplicateAnnotation",
10+
modules = ["~MustDedup|Simple0", "~MustDedup|Simple1", "~MustDedup|Simple2"]
11+
}]} {
12+
13+
// CHECK: hw.hierpath private @nla [@MustDedup::@simple1, @Simple0]
14+
hw.hierpath private @nla [@MustDedup::@simple1, @Simple1]
15+
16+
// CHECK: firrtl.module private @Simple0
17+
firrtl.module private @Simple0() {
18+
%w = firrtl.wire : !firrtl.uint<1>
19+
}
20+
21+
// CHECK-NOT: firrtl.module private @Simple1
22+
firrtl.module private @Simple1() {
23+
%w = firrtl.wire : !firrtl.uint<1>
24+
}
25+
26+
// CHECK-NOT: firrtl.module private @Simple2
27+
firrtl.module private @Simple2() {
28+
%w = firrtl.wire : !firrtl.uint<1>
29+
}
30+
31+
// CHECK: firrtl.module @MustDedup
32+
firrtl.module @MustDedup() {
33+
// CHECK: firrtl.instance simple0 @Simple0
34+
firrtl.instance simple0 @Simple0()
35+
// CHECK: firrtl.instance simple1 sym @simple1 {annotations = [{circt.nonlocal = @nla, class = "test"}]} @Simple0()
36+
firrtl.instance simple1 sym @simple1 {annotations = [{circt.nonlocal = @nla, class = "test"}]} @Simple1()
37+
// CHECK: firrtl.instance simple2 @Simple0
38+
firrtl.instance simple2 @Simple2()
39+
}
40+
}
41+
42+
// Test with multiple NLAs referencing different modules in the dedup group
43+
firrtl.circuit "MultiNLA" attributes {annotations = [{
44+
class = "firrtl.transforms.MustDeduplicateAnnotation",
45+
modules = ["~MultiNLA|ModA", "~MultiNLA|ModB"]
46+
}]} {
47+
48+
// CHECK: hw.hierpath private @nla1 [@MultiNLA::@instA, @ModA]
49+
hw.hierpath private @nla1 [@MultiNLA::@instA, @ModA]
50+
// CHECK: hw.hierpath private @nla2 [@MultiNLA::@instB, @ModA]
51+
hw.hierpath private @nla2 [@MultiNLA::@instB, @ModB]
52+
53+
// CHECK: firrtl.module private @ModA
54+
firrtl.module private @ModA() {
55+
%w = firrtl.wire : !firrtl.uint<1>
56+
}
57+
58+
// CHECK-NOT: firrtl.module private @ModB
59+
firrtl.module private @ModB() {
60+
%w = firrtl.wire : !firrtl.uint<1>
61+
}
62+
63+
// CHECK: firrtl.module @MultiNLA
64+
firrtl.module @MultiNLA() {
65+
// CHECK: firrtl.instance instA sym @instA {annotations = [{circt.nonlocal = @nla1, class = "test1"}]} @ModA()
66+
firrtl.instance instA sym @instA {annotations = [{circt.nonlocal = @nla1, class = "test1"}]} @ModA()
67+
// CHECK: firrtl.instance instB sym @instB {annotations = [{circt.nonlocal = @nla2, class = "test2"}]} @ModA()
68+
firrtl.instance instB sym @instB {annotations = [{circt.nonlocal = @nla2, class = "test2"}]} @ModB()
69+
}
70+
}
71+
72+
// Test with multiple NLAs referencing different nested modules in the dedup group
73+
firrtl.circuit "MultiSubNLA" attributes {annotations = [{
74+
class = "firrtl.transforms.MustDeduplicateAnnotation",
75+
modules = ["~MultiSubNLA|ModA", "~MultiSubNLA|ModB"]
76+
}]} {
77+
78+
// CHECK: hw.hierpath private @nla1 [@MultiSubNLA::@instA, @ModA::@sub, @SubModA::@wire]
79+
hw.hierpath private @nla1 [@MultiSubNLA::@instA, @ModA::@sub, @SubModA::@wire]
80+
// CHECK: hw.hierpath private @nla2 [@MultiSubNLA::@instB, @ModA::@sub, @SubModA::@wire]
81+
hw.hierpath private @nla2 [@MultiSubNLA::@instB, @ModB::@sub, @SubModB::@wire]
82+
83+
// CHECK: firrtl.module private @ModA
84+
firrtl.module private @ModA() {
85+
firrtl.instance sub sym @sub @SubModA()
86+
}
87+
88+
// CHECK-NOT: firrtl.module private @ModB
89+
firrtl.module private @ModB() {
90+
firrtl.instance sub sym @sub @SubModB()
91+
}
92+
93+
// CHECK: firrtl.module private @SubModA
94+
firrtl.module private @SubModA() {
95+
%w = firrtl.wire sym @wire {annotations = [{circt.nonlocal = @nla1, class = "test1"}]} : !firrtl.uint<1>
96+
}
97+
98+
// CHECK: firrtl.module private @SubModB
99+
firrtl.module private @SubModB() {
100+
%w = firrtl.wire sym @wire {annotations = [{circt.nonlocal = @nla2, class = "test2"}]} : !firrtl.uint<1>
101+
}
102+
103+
// CHECK: firrtl.module @MultiSubNLA
104+
firrtl.module @MultiSubNLA() {
105+
// CHECK: firrtl.instance instA sym @instA @ModA()
106+
firrtl.instance instA sym @instA @ModA()
107+
// CHECK: firrtl.instance instB sym @instB @ModA()
108+
firrtl.instance instB sym @instB @ModB()
109+
}
110+
}

0 commit comments

Comments
 (0)