Skip to content

Commit bafe0f4

Browse files
committed
[FIRRTL] Add module swapping reduction
Add a reduction pattern that changes instances to point to a smaller module with the same port signature. This can be very effective under MustDedup and in situations where a module cannot be converted into an extmodule, but the exact function of the instantiated module is not relevant.
1 parent bd6d6e3 commit bafe0f4

File tree

3 files changed

+363
-0
lines changed

3 files changed

+363
-0
lines changed

lib/Dialect/FIRRTL/FIRRTLReductions.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,6 +1123,149 @@ struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
11231123
bool isOneShot() const override { return true; }
11241124
};
11251125

1126+
/// A reduction pattern that groups modules by their port signature (types and
1127+
/// directions) and replaces instances with the smallest module in each group.
1128+
/// This helps reduce the IR by consolidating functionally equivalent modules
1129+
/// based on their interface.
1130+
///
1131+
/// The pattern works by:
1132+
/// 1. Grouping all modules by their port signature (port types and directions)
1133+
/// 2. For each group with multiple modules, finding the smallest module using
1134+
/// the module size cache
1135+
/// 3. Replacing all instances of larger modules with instances of the smallest
1136+
/// module in the same group
1137+
/// 4. Removing the larger modules from the circuit
1138+
///
1139+
/// This reduction is useful for reducing circuits where multiple modules have
1140+
/// the same interface but different implementations, allowing the reducer to
1141+
/// try the smallest implementation first.
1142+
struct ModuleSwapper : public OpReduction<InstanceOp> {
1143+
// Per-circuit state containing all the information needed for module swapping
1144+
using PortSignature = SmallVector<std::pair<Type, Direction>>;
1145+
struct CircuitState {
1146+
DenseMap<PortSignature, SmallVector<FModuleLike, 4>> moduleTypeGroups;
1147+
DenseMap<StringAttr, FModuleLike> instanceToCanonicalModule;
1148+
std::unique_ptr<NLATable> nlaTable;
1149+
};
1150+
1151+
void beforeReduction(mlir::ModuleOp op) override {
1152+
symbols.clear();
1153+
nlaRemover.clear();
1154+
moduleSizes.clear();
1155+
circuitStates.clear();
1156+
1157+
// Collect module type groups and NLA tables for all circuits up front
1158+
op.walk([&](CircuitOp circuitOp) {
1159+
auto &state = circuitStates[circuitOp];
1160+
state.nlaTable = std::make_unique<NLATable>(circuitOp);
1161+
buildModuleTypeGroups(circuitOp, state);
1162+
});
1163+
}
1164+
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
1165+
1166+
/// Create a vector of port type-direction pairs for the given FIRRTL module.
1167+
/// This ignores port names, allowing modules with the same port types and
1168+
/// directions but different port names to be considered equivalent for
1169+
/// swapping.
1170+
PortSignature getModulePortSignature(FModuleLike module) {
1171+
PortSignature signature;
1172+
for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
1173+
signature.emplace_back(module.getPortType(i), module.getPortDirection(i));
1174+
}
1175+
return signature;
1176+
}
1177+
1178+
/// Group modules by their port signature and find the smallest in each group.
1179+
void buildModuleTypeGroups(CircuitOp circuitOp, CircuitState &state) {
1180+
// Group modules by their port signature
1181+
for (auto module : circuitOp.getBodyBlock()->getOps<FModuleLike>()) {
1182+
auto signature = getModulePortSignature(module);
1183+
state.moduleTypeGroups[signature].push_back(module);
1184+
}
1185+
1186+
// For each group, find the smallest module
1187+
for (auto &[signature, modules] : state.moduleTypeGroups) {
1188+
if (modules.size() <= 1)
1189+
continue;
1190+
1191+
FModuleLike smallestModule = nullptr;
1192+
uint64_t smallestSize = UINT64_MAX;
1193+
1194+
for (auto module : modules) {
1195+
uint64_t size = moduleSizes.getModuleSize(module, symbols);
1196+
if (size < smallestSize) {
1197+
smallestSize = size;
1198+
smallestModule = module;
1199+
}
1200+
}
1201+
1202+
// Map all modules in this group to the smallest one
1203+
for (auto module : modules) {
1204+
if (module != smallestModule) {
1205+
state.instanceToCanonicalModule[module.getModuleNameAttr()] =
1206+
smallestModule;
1207+
}
1208+
}
1209+
}
1210+
}
1211+
1212+
uint64_t match(InstanceOp instOp) override {
1213+
// Get the circuit this instance belongs to
1214+
auto circuitOp = instOp->getParentOfType<CircuitOp>();
1215+
assert(circuitOp);
1216+
const auto &state = circuitStates.at(circuitOp);
1217+
1218+
// Skip instances that participate in any NLAs
1219+
DenseSet<hw::HierPathOp> nlas;
1220+
state.nlaTable->getInstanceNLAs(instOp, nlas);
1221+
if (!nlas.empty())
1222+
return 0;
1223+
1224+
// Check if this instance can be redirected to a smaller module
1225+
auto moduleName = instOp.getModuleNameAttr().getAttr();
1226+
auto canonicalModule = state.instanceToCanonicalModule.lookup(moduleName);
1227+
if (!canonicalModule)
1228+
return 0;
1229+
1230+
// Benefit is the size difference
1231+
auto currentModule = cast<FModuleLike>(
1232+
instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
1233+
uint64_t currentSize = moduleSizes.getModuleSize(currentModule, symbols);
1234+
uint64_t canonicalSize =
1235+
moduleSizes.getModuleSize(canonicalModule, symbols);
1236+
return currentSize > canonicalSize ? currentSize - canonicalSize : 1;
1237+
}
1238+
1239+
LogicalResult rewrite(InstanceOp instOp) override {
1240+
// Get the circuit this instance belongs to
1241+
auto circuitOp = instOp->getParentOfType<CircuitOp>();
1242+
assert(circuitOp);
1243+
const auto &state = circuitStates.at(circuitOp);
1244+
1245+
// Replace the instantiated module with the canonical module.
1246+
auto canonicalModule = state.instanceToCanonicalModule.at(
1247+
instOp.getModuleNameAttr().getAttr());
1248+
auto canonicalName = canonicalModule.getModuleNameAttr();
1249+
instOp.setModuleNameAttr(FlatSymbolRefAttr::get(canonicalName));
1250+
1251+
// Update port names to match the canonical module
1252+
instOp.setPortNamesAttr(canonicalModule.getPortNamesAttr());
1253+
1254+
return success();
1255+
}
1256+
1257+
std::string getName() const override { return "firrtl-module-swapper"; }
1258+
bool acceptSizeIncrease() const override { return true; }
1259+
1260+
private:
1261+
::detail::SymbolCache symbols;
1262+
NLARemover nlaRemover;
1263+
ModuleSizeCache moduleSizes;
1264+
1265+
// Per-circuit state containing all module swapping information
1266+
DenseMap<CircuitOp, CircuitState> circuitStates;
1267+
};
1268+
11261269
/// A reduction pattern that handles MustDedup annotations by replacing all
11271270
/// module names in a dedup group with a single module name. This helps reduce
11281271
/// the IR by consolidating module references that are required to be identical.
@@ -1310,6 +1453,7 @@ void firrtl::FIRRTLReducePatternDialectInterface::populateReducePatterns(
13101453
// prioritized). For example, things that can knock out entire modules while
13111454
// being cheap should be tried first (and thus have higher benefit), before
13121455
// trying to tweak operands of individual arithmetic ops.
1456+
patterns.add<ModuleSwapper, 32>();
13131457
patterns.add<ForceDedup, 31>();
13141458
patterns.add<PassReduction, 30>(
13151459
getContext(),
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// UNSUPPORTED: system-windows
2+
// See https://github.com/llvm/circt/issues/4129
3+
// RUN: circt-reduce --test /usr/bin/env --test-arg true --include firrtl-module-swapper --max-chunks=1 %s | FileCheck %s
4+
5+
// Test that the ModuleSwapper reducer can replace instances of larger modules
6+
// with instances of smaller modules that have the same port signature.
7+
8+
// CHECK-LABEL: firrtl.circuit "ModuleSwapperTest"
9+
firrtl.circuit "ModuleSwapperTest" {
10+
// CHECK: firrtl.module private @SmallModule
11+
firrtl.module private @SmallModule(in %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>) {
12+
// Small module with minimal implementation
13+
firrtl.connect %b, %a : !firrtl.uint<1>, !firrtl.uint<1>
14+
}
15+
16+
// CHECK: firrtl.module private @LargeModule
17+
firrtl.module private @LargeModule(in %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>) {
18+
// Large module with more complex implementation (same interface as SmallModule)
19+
%wire1 = firrtl.wire : !firrtl.uint<1>
20+
%wire2 = firrtl.wire : !firrtl.uint<1>
21+
%wire3 = firrtl.wire : !firrtl.uint<1>
22+
firrtl.connect %wire1, %a : !firrtl.uint<1>, !firrtl.uint<1>
23+
firrtl.connect %wire2, %wire1 : !firrtl.uint<1>, !firrtl.uint<1>
24+
firrtl.connect %wire3, %wire2 : !firrtl.uint<1>, !firrtl.uint<1>
25+
firrtl.connect %b, %wire3 : !firrtl.uint<1>, !firrtl.uint<1>
26+
}
27+
28+
// CHECK: firrtl.module private @AnotherLargeModule
29+
firrtl.module private @AnotherLargeModule(in %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>) {
30+
// Another large module with same interface but different implementation
31+
%wire1 = firrtl.wire : !firrtl.uint<1>
32+
%wire2 = firrtl.wire : !firrtl.uint<1>
33+
%wire3 = firrtl.wire : !firrtl.uint<1>
34+
%wire4 = firrtl.wire : !firrtl.uint<1>
35+
firrtl.connect %wire1, %a : !firrtl.uint<1>, !firrtl.uint<1>
36+
firrtl.connect %wire2, %wire1 : !firrtl.uint<1>, !firrtl.uint<1>
37+
firrtl.connect %wire3, %wire2 : !firrtl.uint<1>, !firrtl.uint<1>
38+
firrtl.connect %wire4, %wire3 : !firrtl.uint<1>, !firrtl.uint<1>
39+
firrtl.connect %b, %wire4 : !firrtl.uint<1>, !firrtl.uint<1>
40+
}
41+
42+
// Module with different interface - should not be affected
43+
// CHECK: firrtl.module private @DifferentInterface
44+
firrtl.module private @DifferentInterface(in %x: !firrtl.uint<2>, out %y: !firrtl.uint<2>) {
45+
%wire = firrtl.wire : !firrtl.uint<2>
46+
firrtl.connect %wire, %x : !firrtl.uint<2>, !firrtl.uint<2>
47+
firrtl.connect %y, %wire : !firrtl.uint<2>, !firrtl.uint<2>
48+
}
49+
50+
// CHECK: firrtl.module @ModuleSwapperTest
51+
firrtl.module @ModuleSwapperTest(in %clk: !firrtl.clock, in %input: !firrtl.uint<1>, out %output1: !firrtl.uint<1>, out %output2: !firrtl.uint<1>, out %output3: !firrtl.uint<1>, out %output4: !firrtl.uint<2>) {
52+
// CHECK: firrtl.instance small @SmallModule
53+
%small_a, %small_b = firrtl.instance small @SmallModule(in a: !firrtl.uint<1>, out b: !firrtl.uint<1>)
54+
55+
// CHECK: firrtl.instance large @SmallModule
56+
%large_a, %large_b = firrtl.instance large @LargeModule(in a: !firrtl.uint<1>, out b: !firrtl.uint<1>)
57+
58+
// CHECK: firrtl.instance another @SmallModule
59+
%another_a, %another_b = firrtl.instance another @AnotherLargeModule(in a: !firrtl.uint<1>, out b: !firrtl.uint<1>)
60+
61+
// This should remain unchanged as it has a different interface
62+
// CHECK: firrtl.instance diff @DifferentInterface
63+
%diff_x, %diff_y = firrtl.instance diff @DifferentInterface(in x: !firrtl.uint<2>, out y: !firrtl.uint<2>)
64+
65+
// Connect inputs
66+
firrtl.connect %small_a, %input : !firrtl.uint<1>, !firrtl.uint<1>
67+
firrtl.connect %large_a, %input : !firrtl.uint<1>, !firrtl.uint<1>
68+
firrtl.connect %another_a, %input : !firrtl.uint<1>, !firrtl.uint<1>
69+
70+
%input_ext = firrtl.pad %input, 2 : (!firrtl.uint<1>) -> !firrtl.uint<2>
71+
firrtl.connect %diff_x, %input_ext : !firrtl.uint<2>, !firrtl.uint<2>
72+
73+
// Connect outputs
74+
firrtl.connect %output1, %small_b : !firrtl.uint<1>, !firrtl.uint<1>
75+
firrtl.connect %output2, %large_b : !firrtl.uint<1>, !firrtl.uint<1>
76+
firrtl.connect %output3, %another_b : !firrtl.uint<1>, !firrtl.uint<1>
77+
firrtl.connect %output4, %diff_y : !firrtl.uint<2>, !firrtl.uint<2>
78+
}
79+
}
80+
81+
// Test with modules that have different port directions but same types
82+
// CHECK-LABEL: firrtl.circuit "DirectionTest"
83+
firrtl.circuit "DirectionTest" {
84+
// CHECK: firrtl.module private @SimpleInOut
85+
firrtl.module private @SimpleInOut(in %in: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
86+
firrtl.connect %out, %in : !firrtl.uint<1>, !firrtl.uint<1>
87+
}
88+
89+
// CHECK: firrtl.module private @ComplexInOut
90+
firrtl.module private @ComplexInOut(in %in: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
91+
%w1 = firrtl.wire : !firrtl.uint<1>
92+
%w2 = firrtl.wire : !firrtl.uint<1>
93+
%w3 = firrtl.wire : !firrtl.uint<1>
94+
firrtl.connect %w1, %in : !firrtl.uint<1>, !firrtl.uint<1>
95+
firrtl.connect %w2, %w1 : !firrtl.uint<1>, !firrtl.uint<1>
96+
firrtl.connect %w3, %w2 : !firrtl.uint<1>, !firrtl.uint<1>
97+
firrtl.connect %out, %w3 : !firrtl.uint<1>, !firrtl.uint<1>
98+
}
99+
100+
// Different direction - should not be grouped with above modules
101+
// CHECK: firrtl.module private @OutIn
102+
firrtl.module private @OutIn(out %out: !firrtl.uint<1>, in %in: !firrtl.uint<1>) {
103+
%wire = firrtl.wire : !firrtl.uint<1>
104+
firrtl.connect %wire, %in : !firrtl.uint<1>, !firrtl.uint<1>
105+
firrtl.connect %out, %wire : !firrtl.uint<1>, !firrtl.uint<1>
106+
}
107+
108+
// CHECK: firrtl.module @DirectionTest
109+
firrtl.module @DirectionTest(in %input: !firrtl.uint<1>, out %output1: !firrtl.uint<1>, out %output2: !firrtl.uint<1>, out %output3: !firrtl.uint<1>) {
110+
// CHECK: firrtl.instance simple @SimpleInOut
111+
%simple_in, %simple_out = firrtl.instance simple @SimpleInOut(in in: !firrtl.uint<1>, out out: !firrtl.uint<1>)
112+
113+
// CHECK: firrtl.instance complex @SimpleInOut
114+
%complex_in, %complex_out = firrtl.instance complex @ComplexInOut(in in: !firrtl.uint<1>, out out: !firrtl.uint<1>)
115+
116+
// This should remain unchanged due to different port order
117+
// CHECK: firrtl.instance outIn @OutIn
118+
%outIn_out, %outIn_in = firrtl.instance outIn @OutIn(out out: !firrtl.uint<1>, in in: !firrtl.uint<1>)
119+
120+
firrtl.connect %simple_in, %input : !firrtl.uint<1>, !firrtl.uint<1>
121+
firrtl.connect %complex_in, %input : !firrtl.uint<1>, !firrtl.uint<1>
122+
firrtl.connect %outIn_in, %input : !firrtl.uint<1>, !firrtl.uint<1>
123+
124+
firrtl.connect %output1, %simple_out : !firrtl.uint<1>, !firrtl.uint<1>
125+
firrtl.connect %output2, %complex_out : !firrtl.uint<1>, !firrtl.uint<1>
126+
firrtl.connect %output3, %outIn_out : !firrtl.uint<1>, !firrtl.uint<1>
127+
}
128+
}
129+
130+
// Test that instances participating in NLAs are not swapped
131+
// CHECK-LABEL: firrtl.circuit "NLATest"
132+
firrtl.circuit "NLATest" {
133+
// NLA that references an instance
134+
// CHECK: hw.hierpath private @nla [@NLATest::@large, @LargeNLA]
135+
hw.hierpath private @nla [@NLATest::@large, @LargeNLA]
136+
137+
// CHECK: firrtl.module private @SmallNLA
138+
firrtl.module private @SmallNLA(in %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>) {
139+
// Small simple module
140+
firrtl.connect %b, %a : !firrtl.uint<1>, !firrtl.uint<1>
141+
}
142+
143+
// CHECK: firrtl.module private @LargeNLA
144+
firrtl.module private @LargeNLA(in %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>) {
145+
// Large module with same interface as SmallNLA
146+
%wire1 = firrtl.wire : !firrtl.uint<1>
147+
%wire2 = firrtl.wire : !firrtl.uint<1>
148+
%wire3 = firrtl.wire : !firrtl.uint<1>
149+
%wire4 = firrtl.wire : !firrtl.uint<1>
150+
firrtl.connect %wire1, %a : !firrtl.uint<1>, !firrtl.uint<1>
151+
firrtl.connect %wire2, %wire1 : !firrtl.uint<1>, !firrtl.uint<1>
152+
firrtl.connect %wire3, %wire2 : !firrtl.uint<1>, !firrtl.uint<1>
153+
firrtl.connect %wire4, %wire3 : !firrtl.uint<1>, !firrtl.uint<1>
154+
firrtl.connect %b, %wire4 : !firrtl.uint<1>, !firrtl.uint<1>
155+
}
156+
157+
// CHECK: firrtl.module @NLATest
158+
firrtl.module @NLATest(in %input: !firrtl.uint<1>, out %output1: !firrtl.uint<1>, out %output2: !firrtl.uint<1>, out %output3: !firrtl.uint<1>, out %output4: !firrtl.uint<1>) {
159+
// This instance should remain as SmallNLA (already the smallest)
160+
// CHECK: firrtl.instance small @SmallNLA
161+
%small_a, %small_b = firrtl.instance small @SmallNLA(in a: !firrtl.uint<1>, out b: !firrtl.uint<1>)
162+
163+
// This instance should NOT be swapped because it participates in an NLA
164+
// CHECK: firrtl.instance large sym @large {annotations = [{circt.nonlocal = @nla, class = "test"}]} @LargeNLA
165+
%large_a, %large_b = firrtl.instance large sym @large {annotations = [{circt.nonlocal = @nla, class = "test"}]} @LargeNLA(in a: !firrtl.uint<1>, out b: !firrtl.uint<1>)
166+
167+
// This instance should be swapped because it does not participate in an NLA
168+
// CHECK: firrtl.instance another @SmallNLA
169+
%another_a, %another_b = firrtl.instance another @LargeNLA(in a: !firrtl.uint<1>, out b: !firrtl.uint<1>)
170+
171+
firrtl.connect %small_a, %input : !firrtl.uint<1>, !firrtl.uint<1>
172+
firrtl.connect %large_a, %input : !firrtl.uint<1>, !firrtl.uint<1>
173+
firrtl.connect %another_a, %input : !firrtl.uint<1>, !firrtl.uint<1>
174+
firrtl.connect %output1, %small_b : !firrtl.uint<1>, !firrtl.uint<1>
175+
firrtl.connect %output2, %large_b : !firrtl.uint<1>, !firrtl.uint<1>
176+
firrtl.connect %output3, %another_b : !firrtl.uint<1>, !firrtl.uint<1>
177+
}
178+
}
179+
180+
// Test that modules with different port names but same types are swapped
181+
// CHECK-LABEL: firrtl.circuit "PortNameTest"
182+
firrtl.circuit "PortNameTest" {
183+
// CHECK: firrtl.module private @SmallPortNames
184+
firrtl.module private @SmallPortNames(in %input: !firrtl.uint<1>, out %output: !firrtl.uint<1>) {
185+
// Small module with simple port names
186+
firrtl.connect %output, %input : !firrtl.uint<1>, !firrtl.uint<1>
187+
}
188+
189+
// CHECK: firrtl.module private @LargePortNames
190+
firrtl.module private @LargePortNames(in %data_in: !firrtl.uint<1>, out %data_out: !firrtl.uint<1>) {
191+
// Large module with different port names but same types
192+
%wire1 = firrtl.wire : !firrtl.uint<1>
193+
%wire2 = firrtl.wire : !firrtl.uint<1>
194+
%wire3 = firrtl.wire : !firrtl.uint<1>
195+
firrtl.connect %wire1, %data_in : !firrtl.uint<1>, !firrtl.uint<1>
196+
firrtl.connect %wire2, %wire1 : !firrtl.uint<1>, !firrtl.uint<1>
197+
firrtl.connect %wire3, %wire2 : !firrtl.uint<1>, !firrtl.uint<1>
198+
firrtl.connect %data_out, %wire3 : !firrtl.uint<1>, !firrtl.uint<1>
199+
}
200+
201+
// CHECK: firrtl.module @PortNameTest
202+
firrtl.module @PortNameTest(in %clk: !firrtl.clock, in %input: !firrtl.uint<1>, out %output1: !firrtl.uint<1>, out %output2: !firrtl.uint<1>) {
203+
// CHECK: firrtl.instance small @SmallPortNames
204+
%small_input, %small_output = firrtl.instance small @SmallPortNames(in input: !firrtl.uint<1>, out output: !firrtl.uint<1>)
205+
206+
// This should be swapped to SmallPortNames and port names should be updated
207+
// CHECK: firrtl.instance large @SmallPortNames(in input: !firrtl.uint<1>, out output: !firrtl.uint<1>)
208+
%large_data_in, %large_data_out = firrtl.instance large @LargePortNames(in data_in: !firrtl.uint<1>, out data_out: !firrtl.uint<1>)
209+
210+
// Connect inputs
211+
firrtl.connect %small_input, %input : !firrtl.uint<1>, !firrtl.uint<1>
212+
firrtl.connect %large_data_in, %input : !firrtl.uint<1>, !firrtl.uint<1>
213+
214+
// Connect outputs
215+
firrtl.connect %output1, %small_output : !firrtl.uint<1>, !firrtl.uint<1>
216+
firrtl.connect %output2, %large_data_out : !firrtl.uint<1>, !firrtl.uint<1>
217+
}
218+
}

test/Dialect/FIRRTL/Reduction/pattern-registration.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
// CHECK-DAG: firrtl-lower-chirrtl
2929
// CHECK-DAG: firrtl-lower-types
3030
// CHECK-DAG: firrtl-module-externalizer
31+
// CHECK-DAG: firrtl-module-swapper
3132
// CHECK-DAG: firrtl-operand0-forwarder
3233
// CHECK-DAG: firrtl-operand1-forwarder
3334
// CHECK-DAG: firrtl-operand2-forwarder

0 commit comments

Comments
 (0)