@@ -622,38 +622,79 @@ struct ConnectInvalidator : public Reduction {
622622 bool acceptSizeIncrease () const override { return true ; }
623623};
624624
625- // / A sample reduction pattern that removes FIRRTL annotations from ports and
626- // / operations.
625+ // / A reduction pattern that removes FIRRTL annotations from ports and
626+ // / operations. This generates one match per annotation and port annotation,
627+ // / allowing selective removal of individual annotations.
627628struct AnnotationRemover : public Reduction {
628629 void beforeReduction (mlir::ModuleOp op) override { nlaRemover.clear (); }
629630 void afterReduction (mlir::ModuleOp op) override { nlaRemover.remove (op); }
630- uint64_t match (Operation *op) override {
631- if (auto annos = op->getAttrOfType <ArrayAttr>(" annotations" ))
632- if (!annos.empty ())
633- return 1 ;
634- if (auto annos = op->getAttrOfType <ArrayAttr>(" portAnnotations" ))
635- if (llvm::any_of (annos.getAsRange <ArrayAttr>(),
636- [](auto portAnnos) { return !portAnnos.empty (); }))
637- return 1 ;
638- return 0 ;
631+
632+ void matches (Operation *op,
633+ llvm::function_ref<void (uint64_t , uint64_t )> addMatch) override {
634+ uint64_t matchId = 0 ;
635+
636+ // Generate matches for regular annotations
637+ if (auto annos = op->getAttrOfType <ArrayAttr>(" annotations" )) {
638+ for (auto anno : annos) {
639+ (void )anno;
640+ addMatch (1 , matchId++);
641+ }
642+ }
643+
644+ // Generate matches for port annotations
645+ if (auto portAnnos = op->getAttrOfType <ArrayAttr>(" portAnnotations" )) {
646+ for (auto portAnnoArray : portAnnos) {
647+ if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray)) {
648+ for (auto anno : portAnnoArrayAttr) {
649+ (void )anno;
650+ addMatch (1 , matchId++);
651+ }
652+ }
653+ }
654+ }
639655 }
640- LogicalResult rewrite (Operation *op) override {
641- auto emptyArray = ArrayAttr::get (op->getContext (), {});
642- if (auto annos = op->getAttr (" annotations" )) {
643- nlaRemover.markNLAsInAnnotation (annos);
644- op->setAttr (" annotations" , emptyArray);
656+
657+ LogicalResult rewriteMatches (Operation *op,
658+ ArrayRef<uint64_t > matches) override {
659+ // Convert matches to a set for fast lookup
660+ llvm::SmallDenseSet<uint64_t , 4 > matchesSet (matches.begin (), matches.end ());
661+
662+ // Lambda to process annotations and filter out matched ones
663+ uint64_t matchId = 0 ;
664+ auto processAnnotations =
665+ [&](ArrayRef<Attribute> annotations) -> ArrayAttr {
666+ SmallVector<Attribute> newAnnotations;
667+ for (auto anno : annotations) {
668+ if (!matchesSet.contains (matchId)) {
669+ newAnnotations.push_back (anno);
670+ } else {
671+ // Mark NLAs in the removed annotation for cleanup
672+ nlaRemover.markNLAsInAnnotation (anno);
673+ }
674+ matchId++;
675+ }
676+ return ArrayAttr::get (op->getContext (), newAnnotations);
677+ };
678+
679+ // Remove regular annotations
680+ if (auto annos = op->getAttrOfType <ArrayAttr>(" annotations" )) {
681+ op->setAttr (" annotations" , processAnnotations (annos.getValue ()));
645682 }
646- if (auto annos = op->getAttr (" portAnnotations" )) {
647- nlaRemover.markNLAsInAnnotation (annos);
648- auto attr = emptyArray;
649- if (isa<firrtl::InstanceOp>(op))
650- attr = ArrayAttr::get (
651- op->getContext (),
652- SmallVector<Attribute>(op->getNumResults (), emptyArray));
653- op->setAttr (" portAnnotations" , attr);
683+
684+ // Remove port annotations
685+ if (auto portAnnos = op->getAttrOfType <ArrayAttr>(" portAnnotations" )) {
686+ SmallVector<Attribute> newPortAnnos;
687+ for (auto portAnnoArrayAttr : portAnnos.getAsRange <ArrayAttr>()) {
688+ newPortAnnos.push_back (
689+ processAnnotations (portAnnoArrayAttr.getValue ()));
690+ }
691+ op->setAttr (" portAnnotations" ,
692+ ArrayAttr::get (op->getContext (), newPortAnnos));
654693 }
694+
655695 return success ();
656696 }
697+
657698 std::string getName () const override { return " annotation-remover" ; }
658699 NLARemover nlaRemover;
659700};
0 commit comments