66//
77// ===----------------------------------------------------------------------===//
88//
9- // This pass lowers variadic AndInverter operations to binary AndInverter
10- // operations.
9+ // This pass lowers variadic operations to binary operations using a
10+ // delay-aware algorithm for commutative operations.
1111//
1212// ===----------------------------------------------------------------------===//
1313
14+ #include " circt/Dialect/Comb/CombDialect.h"
15+ #include " circt/Dialect/Comb/CombOps.h"
1416#include " circt/Dialect/HW/HWOps.h"
17+ #include " circt/Dialect/Synth/Analysis/LongestPathAnalysis.h"
1518#include " circt/Dialect/Synth/SynthOps.h"
1619#include " circt/Dialect/Synth/Transforms/SynthPasses.h"
17- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
20+ #include " mlir/IR/OpDefinition.h"
21+ #include " llvm/ADT/PointerIntPair.h"
22+ #include " llvm/ADT/PriorityQueue.h"
1823
1924#define DEBUG_TYPE " synth-lower-variadic"
2025
@@ -29,79 +34,190 @@ using namespace circt;
2934using namespace synth ;
3035
3136// ===----------------------------------------------------------------------===//
32- // Rewrite patterns
37+ // Lower Variadic pass
3338// ===----------------------------------------------------------------------===//
3439
3540namespace {
36- static Value lowerVariadicAndInverterOp (aig::AndInverterOp op,
37- OperandRange operands,
38- ArrayRef<bool > inverts,
39- PatternRewriter &rewriter) {
40- switch (operands.size ()) {
41- case 0 :
42- assert (0 && " cannot be called with empty operand range" );
43- break ;
44- case 1 :
45- if (inverts[0 ])
46- return aig::AndInverterOp::create (rewriter, op.getLoc (), operands[0 ],
47- true );
48- else
49- return operands[0 ];
50- case 2 :
51- return aig::AndInverterOp::create (rewriter, op.getLoc (), operands[0 ],
52- operands[1 ], inverts[0 ], inverts[1 ]);
53- default :
54- auto firstHalf = operands.size () / 2 ;
55- auto lhs =
56- lowerVariadicAndInverterOp (op, operands.take_front (firstHalf),
57- inverts.take_front (firstHalf), rewriter);
58- auto rhs =
59- lowerVariadicAndInverterOp (op, operands.drop_front (firstHalf),
60- inverts.drop_front (firstHalf), rewriter);
61- return aig::AndInverterOp::create (rewriter, op.getLoc (), lhs, rhs);
62- }
6341
64- return Value ();
65- }
42+ // / Helper class for delay-aware variadic operation lowering.
43+ // / Stores a value along with its arrival time for priority queue ordering.
44+ class ValueWithArrivalTime {
45+ // / The value and an optional inversion flag packed together.
46+ // / The inversion flag is used for AndInverterOp lowering.
47+ llvm::PointerIntPair<Value, 1 , bool > value;
6648
67- struct VariadicOpConversion : OpRewritePattern<aig::AndInverterOp> {
68- using OpRewritePattern<aig::AndInverterOp>::OpRewritePattern;
69- LogicalResult matchAndRewrite (aig::AndInverterOp op,
70- PatternRewriter &rewriter) const override {
71- if (op.getInputs ().size () <= 2 )
72- return failure ();
49+ // / The arrival time (delay) of this value in the circuit.
50+ int64_t arrivalTime;
7351
74- // TODO: This is a naive implementation that creates a balanced binary tree.
75- // We can improve by analyzing the dataflow and creating a tree that
76- // improves the critical path or area.
77- rewriter.replaceOp (op,
78- lowerVariadicAndInverterOp (op, op.getOperands (),
79- op.getInverted (), rewriter));
80- return success ();
81- }
82- };
52+ // / Value numbering for deterministic ordering when arrival times are equal.
53+ // / This ensures consistent results across runs when multiple values have
54+ // / the same delay.
55+ size_t valueNumbering = 0 ;
8356
84- } // namespace
57+ public:
58+ ValueWithArrivalTime (Value value, int64_t arrivalTime, bool invert,
59+ size_t valueNumbering)
60+ : value(value, invert), arrivalTime(arrivalTime),
61+ valueNumbering (valueNumbering) {}
8562
86- static void populateLowerVariadicPatterns (RewritePatternSet &patterns) {
87- patterns.add <VariadicOpConversion>(patterns.getContext ());
88- }
63+ Value getValue () const { return value.getPointer (); }
64+ bool isInverted () const { return value.getInt (); }
8965
90- // ===----------------------------------------------------------------------===//
91- // Lower Variadic pass
92- // ===----------------------------------------------------------------------===//
66+ // / Comparison operator for priority queue. Values with earlier arrival times
67+ // / have higher priority. When arrival times are equal, use value numbering
68+ // / for determinism.
69+ bool operator >(const ValueWithArrivalTime &other) const {
70+ return arrivalTime > other.arrivalTime ||
71+ (arrivalTime == other.arrivalTime &&
72+ valueNumbering > other.valueNumbering );
73+ }
74+ };
9375
94- namespace {
9576struct LowerVariadicPass : public impl ::LowerVariadicBase<LowerVariadicPass> {
77+ using LowerVariadicBase::LowerVariadicBase;
9678 void runOnOperation () override ;
9779};
80+
9881} // namespace
9982
83+ // / Construct a balanced binary tree from a variadic operation using a
84+ // / delay-aware algorithm. This function builds the tree by repeatedly combining
85+ // / the two values with the earliest arrival times, which minimizes the critical
86+ // / path delay.
87+ static LogicalResult replaceWithBalancedTree (
88+ IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
89+ Operation *op, llvm::function_ref<bool (OpOperand &)> isInverted,
90+ llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
91+ createBinaryOp) {
92+ // Min-heap priority queue ordered by arrival time.
93+ // Values with earlier arrival times are processed first.
94+ llvm::PriorityQueue<ValueWithArrivalTime, std::vector<ValueWithArrivalTime>,
95+ std::greater<ValueWithArrivalTime>>
96+ queue;
97+
98+ // Counter for deterministic ordering when arrival times are equal.
99+ size_t valueNumber = 0 ;
100+
101+ auto push = [&](Value value, bool invert) {
102+ int64_t delay = 0 ;
103+ // If analysis is available, use it to compute the delay.
104+ // If not available, use zero delay and `valueNumber` will be used instead.
105+ if (analysis) {
106+ auto result = analysis->getMaxDelay (value);
107+ if (failed (result))
108+ return failure ();
109+ delay = *result;
110+ }
111+ ValueWithArrivalTime entry (value, delay, invert, valueNumber++);
112+ queue.push (entry);
113+ return success ();
114+ };
115+
116+ // Enqueue all operands with their arrival times and inversion flags.
117+ for (size_t i = 0 , e = op->getNumOperands (); i < e; ++i)
118+ if (failed (push (op->getOperand (i), isInverted (op->getOpOperand (i)))))
119+ return failure ();
120+
121+ // Build balanced tree by repeatedly combining the two earliest values.
122+ // This greedy approach minimizes the maximum depth of late-arriving signals.
123+ while (queue.size () >= 2 ) {
124+ auto lhs = queue.top ();
125+ queue.pop ();
126+ auto rhs = queue.top ();
127+ queue.pop ();
128+ // Create and enqueue the combined value.
129+ if (failed (push (createBinaryOp (lhs, rhs), /* inverted=*/ false )))
130+ return failure ();
131+ }
132+
133+ // Get the final result and replace the original operation.
134+ auto result = queue.top ().getValue ();
135+ rewriter.replaceOp (op, result);
136+ return success ();
137+ }
138+
100139void LowerVariadicPass::runOnOperation () {
101- RewritePatternSet patterns (&getContext ());
102- populateLowerVariadicPatterns (patterns);
103- mlir::FrozenRewritePatternSet frozen (std::move (patterns));
140+ // Topologically sort operations in graph regions to ensure operands are
141+ // defined before uses.
142+ if (failed (synth::topologicallySortGraphRegionBlocks (
143+ getOperation (), [](Value, Operation *op) -> bool {
144+ return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
145+ op->getDialect ());
146+ })))
147+ return signalPassFailure ();
148+
149+ // Get longest path analysis if timing-aware lowering is enabled.
150+ synth::IncrementalLongestPathAnalysis *analysis = nullptr ;
151+ if (timingAware.getValue ())
152+ analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
153+
154+ auto moduleOp = getOperation ();
155+
156+ // Build set of operation names to lower if specified.
157+ SmallVector<OperationName> names;
158+ for (const auto &name : opNames)
159+ names.push_back (OperationName (name, &getContext ()));
160+
161+ // Return true if the operation should be lowered.
162+ auto shouldLower = [&](Operation *op) {
163+ // If no names specified, lower all variadic ops.
164+ if (names.empty ())
165+ return true ;
166+ return llvm::find (names, op->getName ()) != names.end ();
167+ };
168+
169+ mlir::IRRewriter rewriter (&getContext ());
170+ rewriter.setListener (analysis);
171+
172+ auto result = moduleOp->walk ([&](Operation *op) {
173+ // Skip operations that don't need lowering or are already binary.
174+ if (!shouldLower (op) || op->getNumOperands () <= 2 )
175+ return WalkResult::advance ();
176+
177+ rewriter.setInsertionPoint (op);
178+
179+ // Handle AndInverterOp specially to preserve inversion flags.
180+ if (auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
181+ auto result = replaceWithBalancedTree (
182+ analysis, rewriter, op,
183+ // Check if each operand is inverted.
184+ [&](OpOperand &operand) {
185+ return andInverterOp.isInverted (operand.getOperandNumber ());
186+ },
187+ // Create binary AndInverterOp with inversion flags.
188+ [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
189+ return rewriter.create <aig::AndInverterOp>(
190+ op->getLoc (), lhs.getValue (), rhs.getValue (), lhs.isInverted (),
191+ rhs.isInverted ());
192+ });
193+ return result.succeeded () ? WalkResult::advance ()
194+ : WalkResult::interrupt ();
195+ }
196+
197+ // Handle commutative operations (and, or, xor, mul, add, etc.) using
198+ // delay-aware lowering to minimize critical path.
199+ if (isa_and_nonnull<comb::CombDialect>(op->getDialect ()) &&
200+ op->hasTrait <OpTrait::IsCommutative>()) {
201+ auto result = replaceWithBalancedTree (
202+ analysis, rewriter, op,
203+ // No inversion flags for standard commutative operations.
204+ [](OpOperand &) { return false ; },
205+ // Create binary operation with the same operation type.
206+ [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
207+ OperationState state (op->getLoc (), op->getName ());
208+ state.addOperands (ValueRange{lhs.getValue (), rhs.getValue ()});
209+ state.addTypes (op->getResult (0 ).getType ());
210+ auto *newOp = Operation::create (state);
211+ rewriter.insert (newOp);
212+ return newOp->getResult (0 );
213+ });
214+ return result.succeeded () ? WalkResult::advance ()
215+ : WalkResult::interrupt ();
216+ }
217+
218+ return WalkResult::advance ();
219+ });
104220
105- if (failed ( mlir::applyPatternsGreedily ( getOperation (), frozen) ))
221+ if (result. wasInterrupted ( ))
106222 return signalPassFailure ();
107223}
0 commit comments