77// ===----------------------------------------------------------------------===//
88
99#include " circt/Dialect/FIRRTL/FIRRTLReductions.h"
10+ #include " circt/Dialect/FIRRTL/AnnotationDetails.h"
1011#include " circt/Dialect/FIRRTL/CHIRRTLDialect.h"
12+ #include " circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
1113#include " circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
1214#include " circt/Dialect/FIRRTL/FIRRTLOps.h"
1315#include " circt/Dialect/FIRRTL/FIRRTLUtils.h"
2527
2628using namespace mlir ;
2729using namespace circt ;
30+ using namespace firrtl ;
2831
2932// ===----------------------------------------------------------------------===//
3033// Utilities
@@ -179,6 +182,8 @@ struct NLARemover {
179182// Reduction patterns
180183// ===----------------------------------------------------------------------===//
181184
185+ namespace {
186+
182187// / A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
183188struct FIRRTLModuleExternalizer : public OpReduction <firrtl::FModuleOp> {
184189 void beforeReduction (mlir::ModuleOp op) override {
@@ -1118,6 +1123,182 @@ struct ModuleNameSanitizer : OpReduction<firrtl::CircuitOp> {
11181123 bool isOneShot () const override { return true ; }
11191124};
11201125
1126+ // / A reduction pattern that handles MustDedup annotations by replacing all
1127+ // / module names in a dedup group with a single module name. This helps reduce
1128+ // / the IR by consolidating module references that are required to be identical.
1129+ // /
1130+ // / The pattern works by:
1131+ // / 1. Finding all MustDeduplicateAnnotation annotations on the circuit
1132+ // / 2. For each dedup group, using the first module as the canonical name
1133+ // / 3. Replacing all instance references to other modules in the group with
1134+ // / references to the canonical module
1135+ // / 4. Removing the non-canonical modules from the circuit
1136+ // / 5. Removing the processed MustDedup annotation
1137+ // /
1138+ // / This reduction is particularly useful for reducing large circuits where
1139+ // / multiple modules are known to be identical but haven't been deduplicated
1140+ // / yet.
1141+ struct ForceDedup : public OpReduction <CircuitOp> {
1142+ void beforeReduction (mlir::ModuleOp op) override {
1143+ symbols.clear ();
1144+ nlaRemover.clear ();
1145+ }
1146+ void afterReduction (mlir::ModuleOp op) override { nlaRemover.remove (op); }
1147+
1148+ // / Collect all MustDedup annotations and create matches for each dedup group.
1149+ void matches (CircuitOp circuitOp,
1150+ llvm::function_ref<void (uint64_t , uint64_t )> addMatch) override {
1151+ auto annotations = AnnotationSet (circuitOp);
1152+ for (auto [annoIdx, anno] : llvm::enumerate (annotations)) {
1153+ if (!anno.isClass (mustDedupAnnoClass))
1154+ continue ;
1155+
1156+ auto modulesAttr = anno.getMember <ArrayAttr>(" modules" );
1157+ if (!modulesAttr)
1158+ continue ;
1159+
1160+ // Each dedup group gets its own match with benefit proportional to group
1161+ // size.
1162+ uint64_t benefit = modulesAttr.size ();
1163+ addMatch (benefit, annoIdx);
1164+ }
1165+ }
1166+
1167+ LogicalResult rewriteMatches (CircuitOp circuitOp,
1168+ ArrayRef<uint64_t > matches) override {
1169+ auto *context = circuitOp->getContext ();
1170+ NLATable nlaTable (circuitOp);
1171+ hw::InnerSymbolTableCollection innerSymTables;
1172+ auto annotations = AnnotationSet (circuitOp);
1173+ SmallVector<Annotation> newAnnotations;
1174+
1175+ for (auto [annoIdx, anno] : llvm::enumerate (annotations)) {
1176+ // Check if this annotation was selected.
1177+ if (!llvm::is_contained (matches, annoIdx)) {
1178+ newAnnotations.push_back (anno);
1179+ continue ;
1180+ }
1181+ auto modulesAttr = anno.getMember <ArrayAttr>(" modules" );
1182+ assert (anno.isClass (mustDedupAnnoClass) && modulesAttr &&
1183+ modulesAttr.size () >= 2 );
1184+
1185+ // Extract module names from the dedup group.
1186+ SmallVector<StringAttr> moduleNames;
1187+ for (auto moduleRef : modulesAttr.getAsRange <StringAttr>()) {
1188+ // Parse "~CircuitName|ModuleName" format.
1189+ auto refStr = moduleRef.getValue ();
1190+ auto pipePos = refStr.find (' |' );
1191+ if (pipePos != StringRef::npos && pipePos + 1 < refStr.size ()) {
1192+ auto moduleName = refStr.substr (pipePos + 1 );
1193+ moduleNames.push_back (StringAttr::get (context, moduleName));
1194+ }
1195+ }
1196+
1197+ // Simply drop the annotation if there's only one module.
1198+ if (moduleNames.size () < 2 )
1199+ continue ;
1200+
1201+ // Replace all instances and references to other modules with the
1202+ // first module.
1203+ replaceModuleReferences (circuitOp, moduleNames, nlaTable, innerSymTables);
1204+ nlaRemover.markNLAsInAnnotation (anno.getAttr ());
1205+ }
1206+ if (newAnnotations.size () == annotations.size ())
1207+ return failure ();
1208+
1209+ // Update circuit annotations.
1210+ AnnotationSet newAnnoSet (newAnnotations, context);
1211+ newAnnoSet.applyToOperation (circuitOp);
1212+ return success ();
1213+ }
1214+
1215+ std::string getName () const override { return " firrtl-force-dedup" ; }
1216+ bool acceptSizeIncrease () const override { return true ; }
1217+
1218+ private:
1219+ // / Replace all references to modules in the dedup group with the canonical
1220+ // / module name
1221+ void replaceModuleReferences (CircuitOp circuitOp,
1222+ ArrayRef<StringAttr> moduleNames,
1223+ NLATable &nlaTable,
1224+ hw::InnerSymbolTableCollection &innerSymTables) {
1225+ auto *tableOp = SymbolTable::getNearestSymbolTable (circuitOp);
1226+ auto &symbolTable = symbols.getSymbolTable (tableOp);
1227+ auto *context = circuitOp->getContext ();
1228+ auto innerRefs = hw::InnerRefNamespace{symbolTable, innerSymTables};
1229+
1230+ // Collect the modules.
1231+ FModuleLike canonicalModule;
1232+ SmallVector<FModuleLike> modulesToReplace;
1233+ for (auto name : moduleNames) {
1234+ if (auto mod = symbolTable.lookup <FModuleLike>(name)) {
1235+ if (!canonicalModule)
1236+ canonicalModule = mod;
1237+ else
1238+ modulesToReplace.push_back (mod);
1239+ }
1240+ }
1241+ if (modulesToReplace.empty ())
1242+ return ;
1243+
1244+ // Replace all instance references.
1245+ auto canonicalName = canonicalModule.getModuleNameAttr ();
1246+ auto canonicalRef = FlatSymbolRefAttr::get (canonicalName);
1247+ circuitOp.walk ([&](InstanceOp instOp) {
1248+ auto moduleName = instOp.getModuleNameAttr ().getAttr ();
1249+ if (llvm::is_contained (moduleNames, moduleName) &&
1250+ moduleName != canonicalName) {
1251+ instOp.setModuleNameAttr (canonicalRef);
1252+ instOp.setPortNamesAttr (canonicalModule.getPortNamesAttr ());
1253+ }
1254+ });
1255+
1256+ // Update NLAs to reference the canonical module instead of modules being
1257+ // removed using NLATable for better performance.
1258+ for (auto oldMod : modulesToReplace) {
1259+ SmallVector<hw::HierPathOp> nlaOps (
1260+ nlaTable.lookup (oldMod.getModuleNameAttr ()));
1261+ for (auto nlaOp : nlaOps) {
1262+ nlaTable.erase (nlaOp);
1263+ StringAttr oldModName = oldMod.getModuleNameAttr ();
1264+ StringAttr newModName = canonicalName;
1265+ SmallVector<Attribute, 4 > newPath;
1266+ for (auto nameRef : nlaOp.getNamepath ()) {
1267+ if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
1268+ if (ref.getModule () == oldModName) {
1269+ auto oldInst = innerRefs.lookupOp <FInstanceLike>(ref);
1270+ ref = hw::InnerRefAttr::get (newModName, ref.getName ());
1271+ auto newInst = innerRefs.lookupOp <FInstanceLike>(ref);
1272+ if (oldInst && newInst) {
1273+ oldModName = oldInst.getReferencedModuleNameAttr ();
1274+ newModName = newInst.getReferencedModuleNameAttr ();
1275+ }
1276+ }
1277+ newPath.push_back (ref);
1278+ } else if (cast<FlatSymbolRefAttr>(nameRef).getAttr () == oldModName) {
1279+ newPath.push_back (FlatSymbolRefAttr::get (newModName));
1280+ } else {
1281+ newPath.push_back (nameRef);
1282+ }
1283+ }
1284+ nlaOp.setNamepathAttr (ArrayAttr::get (context, newPath));
1285+ nlaTable.addNLA (nlaOp);
1286+ }
1287+ }
1288+
1289+ // Mark NLAs in modules to be removed.
1290+ for (auto module : modulesToReplace) {
1291+ nlaRemover.markNLAsInOperation (module );
1292+ module ->erase ();
1293+ }
1294+ }
1295+
1296+ ::detail::SymbolCache symbols;
1297+ NLARemover nlaRemover;
1298+ };
1299+
1300+ } // namespace
1301+
11211302// ===----------------------------------------------------------------------===//
11221303// Reduction Registration
11231304// ===----------------------------------------------------------------------===//
@@ -1129,6 +1310,7 @@ void firrtl::FIRRTLReducePatternDialectInterface::populateReducePatterns(
11291310 // prioritized). For example, things that can knock out entire modules while
11301311 // being cheap should be tried first (and thus have higher benefit), before
11311312 // trying to tweak operands of individual arithmetic ops.
1313+ patterns.add <ForceDedup, 31 >();
11321314 patterns.add <PassReduction, 30 >(
11331315 getContext (),
11341316 firrtl::createDropName ({/* preserveMode=*/ PreserveValues::None}), false ,
0 commit comments