Skip to content

Commit ff3aa23

Browse files
committed
[FIRRTL] Make annotation reduction create one match per annotation
Instead of removing all annotations on an operation, make the `AnnotationRemover` reduction produce a match for every single annotation and port annotation on an operation. This allows the reducer to remove a subset of annotations from an operation, which is very valuable in practice.
1 parent bafe0f4 commit ff3aa23

File tree

2 files changed

+89
-32
lines changed

2 files changed

+89
-32
lines changed

lib/Dialect/FIRRTL/FIRRTLReductions.cpp

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -622,38 +622,79 @@ struct ConnectInvalidator : public Reduction {
622622
bool acceptSizeIncrease() const override { return true; }
623623
};
624624

625-
/// A sample reduction pattern that removes FIRRTL annotations from ports and
626-
/// operations.
625+
/// A reduction pattern that removes FIRRTL annotations from ports and
626+
/// operations. This generates one match per annotation and port annotation,
627+
/// allowing selective removal of individual annotations.
627628
struct AnnotationRemover : public Reduction {
628629
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
629630
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
630-
uint64_t match(Operation *op) override {
631-
if (auto annos = op->getAttrOfType<ArrayAttr>("annotations"))
632-
if (!annos.empty())
633-
return 1;
634-
if (auto annos = op->getAttrOfType<ArrayAttr>("portAnnotations"))
635-
if (llvm::any_of(annos.getAsRange<ArrayAttr>(),
636-
[](auto portAnnos) { return !portAnnos.empty(); }))
637-
return 1;
638-
return 0;
631+
632+
void matches(Operation *op,
633+
llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
634+
uint64_t matchId = 0;
635+
636+
// Generate matches for regular annotations
637+
if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
638+
for (auto anno : annos) {
639+
(void)anno;
640+
addMatch(1, matchId++);
641+
}
642+
}
643+
644+
// Generate matches for port annotations
645+
if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
646+
for (auto portAnnoArray : portAnnos) {
647+
if (auto portAnnoArrayAttr = dyn_cast<ArrayAttr>(portAnnoArray)) {
648+
for (auto anno : portAnnoArrayAttr) {
649+
(void)anno;
650+
addMatch(1, matchId++);
651+
}
652+
}
653+
}
654+
}
639655
}
640-
LogicalResult rewrite(Operation *op) override {
641-
auto emptyArray = ArrayAttr::get(op->getContext(), {});
642-
if (auto annos = op->getAttr("annotations")) {
643-
nlaRemover.markNLAsInAnnotation(annos);
644-
op->setAttr("annotations", emptyArray);
656+
657+
LogicalResult rewriteMatches(Operation *op,
658+
ArrayRef<uint64_t> matches) override {
659+
// Convert matches to a set for fast lookup
660+
llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
661+
662+
// Lambda to process annotations and filter out matched ones
663+
uint64_t matchId = 0;
664+
auto processAnnotations =
665+
[&](ArrayRef<Attribute> annotations) -> ArrayAttr {
666+
SmallVector<Attribute> newAnnotations;
667+
for (auto anno : annotations) {
668+
if (!matchesSet.contains(matchId)) {
669+
newAnnotations.push_back(anno);
670+
} else {
671+
// Mark NLAs in the removed annotation for cleanup
672+
nlaRemover.markNLAsInAnnotation(anno);
673+
}
674+
matchId++;
675+
}
676+
return ArrayAttr::get(op->getContext(), newAnnotations);
677+
};
678+
679+
// Remove regular annotations
680+
if (auto annos = op->getAttrOfType<ArrayAttr>("annotations")) {
681+
op->setAttr("annotations", processAnnotations(annos.getValue()));
645682
}
646-
if (auto annos = op->getAttr("portAnnotations")) {
647-
nlaRemover.markNLAsInAnnotation(annos);
648-
auto attr = emptyArray;
649-
if (isa<firrtl::InstanceOp>(op))
650-
attr = ArrayAttr::get(
651-
op->getContext(),
652-
SmallVector<Attribute>(op->getNumResults(), emptyArray));
653-
op->setAttr("portAnnotations", attr);
683+
684+
// Remove port annotations
685+
if (auto portAnnos = op->getAttrOfType<ArrayAttr>("portAnnotations")) {
686+
SmallVector<Attribute> newPortAnnos;
687+
for (auto portAnnoArrayAttr : portAnnos.getAsRange<ArrayAttr>()) {
688+
newPortAnnos.push_back(
689+
processAnnotations(portAnnoArrayAttr.getValue()));
690+
}
691+
op->setAttr("portAnnotations",
692+
ArrayAttr::get(op->getContext(), newPortAnnos));
654693
}
694+
655695
return success();
656696
}
697+
657698
std::string getName() const override { return "annotation-remover"; }
658699
NLARemover nlaRemover;
659700
};
Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
1+
// RUN: circt-reduce %s --test /usr/bin/env --test-arg grep --test-arg -q --test-arg "class = \"a\"" --include annotation-remover --keep-best=0 | FileCheck %s --check-prefixes=CHECK,CHECK-A
2+
// RUN: circt-reduce %s --test /usr/bin/env --test-arg grep --test-arg -q --test-arg "class = \"x\"" --include annotation-remover --keep-best=0 | FileCheck %s --check-prefixes=CHECK,CHECK-X
3+
14
// UNSUPPORTED: system-windows
25
// See https://github.com/llvm/circt/issues/4129
3-
// RUN: circt-reduce %s --test /usr/bin/env --test-arg grep --test-arg -q --test-arg "%anotherWire = firrtl.wire" --keep-best=0 --include annotation-remover | FileCheck %s
46

5-
firrtl.circuit "Foo" {
6-
// CHECK: firrtl.module @Foo
7-
// CHECK: %anotherWire = firrtl.wire
8-
// CHECK-NOT: annotations
9-
firrtl.module @Foo() {
10-
%oneWire = firrtl.wire : !firrtl.uint<1>
11-
%anotherWire = firrtl.wire {annotations = [{a}]} : !firrtl.uint<1>
7+
// This test verifies that the AnnotationRemover can selectively remove individual annotations.
8+
// The test uses grep to look for annotation "a", so the reducer should keep that annotation
9+
// but remove annotations "b" and "c" that don't match the grep pattern.
10+
11+
firrtl.circuit "TestAnnotationRemover" {
12+
// CHECK: firrtl.module @TestAnnotationRemover
13+
// CHECK-A-SAME: [{class = "a"}]
14+
firrtl.module @TestAnnotationRemover(
15+
in %a: !firrtl.uint<1> [
16+
{class = "a"},
17+
{class = "b"},
18+
{class = "c"}
19+
]
20+
) {
21+
// CHECK: firrtl.wire
22+
// CHECK-X-SAME: [{class = "x"}]
23+
%someWire = firrtl.wire {annotations = [
24+
{class = "x"},
25+
{class = "y"},
26+
{class = "z"}
27+
]} : !firrtl.uint<8>
1228
}
1329
}

0 commit comments

Comments
 (0)