Skip to content

Commit 26a6b6e

Browse files
uenoku and Max-astro authored
[Synth] Enhance LowerVariadic pass with timing-aware optimization (#9086)
Extend the LowerVariadic pass to handle all commutative operations (and, or, xor, mul, add) from the comb dialect, not just AndInverter ops. The new implementation uses a delay-aware algorithm that builds balanced binary trees by combining values with the earliest arrival times first, minimizing critical path delay. This is implemented with a priority queue that orders values by their arrival time as computed by the IncrementalLongestPathAnalysis. Co-authored-by: Max Zhou <[email protected]>
1 parent b614d5b commit 26a6b6e

File tree

9 files changed

+300
-83
lines changed

9 files changed

+300
-83
lines changed

include/circt/Dialect/Synth/Transforms/SynthPasses.td

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,19 @@ def GenericLutMapper : CutRewriterPassBase<"synth-generic-lut-mapper",
7878
}
7979

8080
def LowerVariadic : Pass<"synth-lower-variadic", "hw::HWModuleOp"> {
81-
let summary = "Lower variadic AndInverter operations to binary AndInverter";
81+
let summary = "Lower variadic operations to binary operations";
82+
let description = [{
83+
This pass lowers variadic operations to binary operations using a
84+
delay-aware algorithm. For commutative operations, it builds a balanced
85+
tree by combining values with the earliest arrival times first to minimize
86+
the critical path.
87+
}];
88+
let options = [
89+
ListOption<"opNames", "op-names", "std::string",
90+
"Specify operation names to lower (empty means all)">,
91+
Option<"timingAware", "timing-aware", "bool", "true",
92+
"Lower operators with timing information">
93+
];
8294
}
8395

8496
def LowerWordToBits : Pass<"synth-lower-word-to-bits", "hw::HWModuleOp"> {

include/circt/Dialect/Synth/Transforms/SynthesisPipeline.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ struct SynthOptimizationPipelineOptions
7272
PassOptions::Option<bool> disableWordToBits{
7373
*this, "disable-word-to-bits",
7474
llvm::cl::desc("Disable LowerWordToBits pass"), llvm::cl::init(false)};
75+
76+
PassOptions::Option<bool> timingAware{
77+
*this, "timing-aware",
78+
llvm::cl::desc("Lower operators in a timing-aware fashion"),
79+
llvm::cl::init(false)};
7580
};
7681

7782
//===----------------------------------------------------------------------===//
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// REQUIRES: libz3
2+
// REQUIRES: circt-lec-jit
3+
4+
// RUN: circt-opt %s -convert-synth-to-comb -o %t.before.mlir
5+
// RUN: circt-opt %s -synth-lower-variadic -convert-synth-to-comb -o %t.after.mlir
6+
// RUN: circt-lec %t.before.mlir %t.after.mlir -c1=AndInverter -c2=AndInverter --shared-libs=%libz3 | FileCheck %s --check-prefix=AND_INVERTER_LEC
7+
// AND_INVERTER_LEC: c1 == c2
8+
hw.module @AndInverter(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, in %f: i2, in %g: i2, out o1: i2) {
9+
%0 = synth.aig.and_inv %d, not %e : i2
10+
%1 = synth.aig.and_inv not %c, not %0, %f : i2
11+
%2 = synth.aig.and_inv %a, not %b, not %1, %g : i2
12+
hw.output %2 : i2
13+
}
14+
15+
// RUN: circt-lec %t.before.mlir %t.after.mlir -c1=VariadicCombOps -c2=VariadicCombOps --shared-libs=%libz3 | FileCheck %s --check-prefix=VARIADIC_COMB_OPS_LEC
16+
// VARIADIC_COMB_OPS_LEC: c1 == c2
17+
hw.module @VariadicCombOps(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, in %f: i2,
18+
out out_and: i2, out out_or: i2, out out_xor: i2) {
19+
%0 = comb.and %a, %b, %c, %d, %e, %f : i2
20+
%1 = comb.or %a, %b, %c, %d, %e, %f : i2
21+
%2 = comb.xor %a, %b, %c, %d, %e, %f : i2
22+
hw.output %0, %1, %2 : i2, i2, i2
23+
}

lib/Dialect/Synth/Transforms/LowerVariadic.cpp

Lines changed: 176 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This pass lowers variadic AndInverter operations to binary AndInverter
10-
// operations.
9+
// This pass lowers variadic operations to binary operations using a
10+
// delay-aware algorithm for commutative operations.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "circt/Dialect/Comb/CombDialect.h"
15+
#include "circt/Dialect/Comb/CombOps.h"
1416
#include "circt/Dialect/HW/HWOps.h"
17+
#include "circt/Dialect/Synth/Analysis/LongestPathAnalysis.h"
1518
#include "circt/Dialect/Synth/SynthOps.h"
1619
#include "circt/Dialect/Synth/Transforms/SynthPasses.h"
17-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
#include "mlir/IR/OpDefinition.h"
21+
#include "llvm/ADT/PointerIntPair.h"
22+
#include "llvm/ADT/PriorityQueue.h"
1823

1924
#define DEBUG_TYPE "synth-lower-variadic"
2025

@@ -29,79 +34,190 @@ using namespace circt;
2934
using namespace synth;
3035

3136
//===----------------------------------------------------------------------===//
32-
// Rewrite patterns
37+
// Lower Variadic pass
3338
//===----------------------------------------------------------------------===//
3439

3540
namespace {
36-
static Value lowerVariadicAndInverterOp(aig::AndInverterOp op,
37-
OperandRange operands,
38-
ArrayRef<bool> inverts,
39-
PatternRewriter &rewriter) {
40-
switch (operands.size()) {
41-
case 0:
42-
assert(0 && "cannot be called with empty operand range");
43-
break;
44-
case 1:
45-
if (inverts[0])
46-
return aig::AndInverterOp::create(rewriter, op.getLoc(), operands[0],
47-
true);
48-
else
49-
return operands[0];
50-
case 2:
51-
return aig::AndInverterOp::create(rewriter, op.getLoc(), operands[0],
52-
operands[1], inverts[0], inverts[1]);
53-
default:
54-
auto firstHalf = operands.size() / 2;
55-
auto lhs =
56-
lowerVariadicAndInverterOp(op, operands.take_front(firstHalf),
57-
inverts.take_front(firstHalf), rewriter);
58-
auto rhs =
59-
lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf),
60-
inverts.drop_front(firstHalf), rewriter);
61-
return aig::AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs);
62-
}
6341

64-
return Value();
65-
}
42+
/// Helper class for delay-aware variadic operation lowering.
43+
/// Stores a value along with its arrival time for priority queue ordering.
44+
class ValueWithArrivalTime {
45+
/// The value and an optional inversion flag packed together.
46+
/// The inversion flag is used for AndInverterOp lowering.
47+
llvm::PointerIntPair<Value, 1, bool> value;
6648

67-
struct VariadicOpConversion : OpRewritePattern<aig::AndInverterOp> {
68-
using OpRewritePattern<aig::AndInverterOp>::OpRewritePattern;
69-
LogicalResult matchAndRewrite(aig::AndInverterOp op,
70-
PatternRewriter &rewriter) const override {
71-
if (op.getInputs().size() <= 2)
72-
return failure();
49+
/// The arrival time (delay) of this value in the circuit.
50+
int64_t arrivalTime;
7351

74-
// TODO: This is a naive implementation that creates a balanced binary tree.
75-
// We can improve by analyzing the dataflow and creating a tree that
76-
// improves the critical path or area.
77-
rewriter.replaceOp(op,
78-
lowerVariadicAndInverterOp(op, op.getOperands(),
79-
op.getInverted(), rewriter));
80-
return success();
81-
}
82-
};
52+
/// Value numbering for deterministic ordering when arrival times are equal.
53+
/// This ensures consistent results across runs when multiple values have
54+
/// the same delay.
55+
size_t valueNumbering = 0;
8356

84-
} // namespace
57+
public:
58+
ValueWithArrivalTime(Value value, int64_t arrivalTime, bool invert,
59+
size_t valueNumbering)
60+
: value(value, invert), arrivalTime(arrivalTime),
61+
valueNumbering(valueNumbering) {}
8562

86-
static void populateLowerVariadicPatterns(RewritePatternSet &patterns) {
87-
patterns.add<VariadicOpConversion>(patterns.getContext());
88-
}
63+
Value getValue() const { return value.getPointer(); }
64+
bool isInverted() const { return value.getInt(); }
8965

90-
//===----------------------------------------------------------------------===//
91-
// Lower Variadic pass
92-
//===----------------------------------------------------------------------===//
66+
/// Comparison operator for priority queue. Values with earlier arrival times
67+
/// have higher priority. When arrival times are equal, use value numbering
68+
/// for determinism.
69+
bool operator>(const ValueWithArrivalTime &other) const {
70+
return arrivalTime > other.arrivalTime ||
71+
(arrivalTime == other.arrivalTime &&
72+
valueNumbering > other.valueNumbering);
73+
}
74+
};
9375

94-
namespace {
9576
struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
77+
using LowerVariadicBase::LowerVariadicBase;
9678
void runOnOperation() override;
9779
};
80+
9881
} // namespace
9982

83+
/// Construct a balanced binary tree from a variadic operation using a
84+
/// delay-aware algorithm. This function builds the tree by repeatedly combining
85+
/// the two values with the earliest arrival times, which minimizes the critical
86+
/// path delay.
87+
static LogicalResult replaceWithBalancedTree(
88+
IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
89+
Operation *op, llvm::function_ref<bool(OpOperand &)> isInverted,
90+
llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
91+
createBinaryOp) {
92+
// Min-heap priority queue ordered by arrival time.
93+
// Values with earlier arrival times are processed first.
94+
llvm::PriorityQueue<ValueWithArrivalTime, std::vector<ValueWithArrivalTime>,
95+
std::greater<ValueWithArrivalTime>>
96+
queue;
97+
98+
// Counter for deterministic ordering when arrival times are equal.
99+
size_t valueNumber = 0;
100+
101+
auto push = [&](Value value, bool invert) {
102+
int64_t delay = 0;
103+
// If analysis is available, use it to compute the delay.
104+
// If not available, use zero delay and `valueNumber` will be used instead.
105+
if (analysis) {
106+
auto result = analysis->getMaxDelay(value);
107+
if (failed(result))
108+
return failure();
109+
delay = *result;
110+
}
111+
ValueWithArrivalTime entry(value, delay, invert, valueNumber++);
112+
queue.push(entry);
113+
return success();
114+
};
115+
116+
// Enqueue all operands with their arrival times and inversion flags.
117+
for (size_t i = 0, e = op->getNumOperands(); i < e; ++i)
118+
if (failed(push(op->getOperand(i), isInverted(op->getOpOperand(i)))))
119+
return failure();
120+
121+
// Build balanced tree by repeatedly combining the two earliest values.
122+
// This greedy approach minimizes the maximum depth of late-arriving signals.
123+
while (queue.size() >= 2) {
124+
auto lhs = queue.top();
125+
queue.pop();
126+
auto rhs = queue.top();
127+
queue.pop();
128+
// Create and enqueue the combined value.
129+
if (failed(push(createBinaryOp(lhs, rhs), /*inverted=*/false)))
130+
return failure();
131+
}
132+
133+
// Get the final result and replace the original operation.
134+
auto result = queue.top().getValue();
135+
rewriter.replaceOp(op, result);
136+
return success();
137+
}
138+
100139
void LowerVariadicPass::runOnOperation() {
101-
RewritePatternSet patterns(&getContext());
102-
populateLowerVariadicPatterns(patterns);
103-
mlir::FrozenRewritePatternSet frozen(std::move(patterns));
140+
// Topologically sort operations in graph regions to ensure operands are
141+
// defined before uses.
142+
if (failed(synth::topologicallySortGraphRegionBlocks(
143+
getOperation(), [](Value, Operation *op) -> bool {
144+
return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
145+
op->getDialect());
146+
})))
147+
return signalPassFailure();
148+
149+
// Get longest path analysis if timing-aware lowering is enabled.
150+
synth::IncrementalLongestPathAnalysis *analysis = nullptr;
151+
if (timingAware.getValue())
152+
analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
153+
154+
auto moduleOp = getOperation();
155+
156+
// Build set of operation names to lower if specified.
157+
SmallVector<OperationName> names;
158+
for (const auto &name : opNames)
159+
names.push_back(OperationName(name, &getContext()));
160+
161+
// Return true if the operation should be lowered.
162+
auto shouldLower = [&](Operation *op) {
163+
// If no names specified, lower all variadic ops.
164+
if (names.empty())
165+
return true;
166+
return llvm::find(names, op->getName()) != names.end();
167+
};
168+
169+
mlir::IRRewriter rewriter(&getContext());
170+
rewriter.setListener(analysis);
171+
172+
auto result = moduleOp->walk([&](Operation *op) {
173+
// Skip operations that don't need lowering or are already binary.
174+
if (!shouldLower(op) || op->getNumOperands() <= 2)
175+
return WalkResult::advance();
176+
177+
rewriter.setInsertionPoint(op);
178+
179+
// Handle AndInverterOp specially to preserve inversion flags.
180+
if (auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
181+
auto result = replaceWithBalancedTree(
182+
analysis, rewriter, op,
183+
// Check if each operand is inverted.
184+
[&](OpOperand &operand) {
185+
return andInverterOp.isInverted(operand.getOperandNumber());
186+
},
187+
// Create binary AndInverterOp with inversion flags.
188+
[&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
189+
return rewriter.create<aig::AndInverterOp>(
190+
op->getLoc(), lhs.getValue(), rhs.getValue(), lhs.isInverted(),
191+
rhs.isInverted());
192+
});
193+
return result.succeeded() ? WalkResult::advance()
194+
: WalkResult::interrupt();
195+
}
196+
197+
// Handle commutative operations (and, or, xor, mul, add, etc.) using
198+
// delay-aware lowering to minimize critical path.
199+
if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
200+
op->hasTrait<OpTrait::IsCommutative>()) {
201+
auto result = replaceWithBalancedTree(
202+
analysis, rewriter, op,
203+
// No inversion flags for standard commutative operations.
204+
[](OpOperand &) { return false; },
205+
// Create binary operation with the same operation type.
206+
[&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
207+
OperationState state(op->getLoc(), op->getName());
208+
state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
209+
state.addTypes(op->getResult(0).getType());
210+
auto *newOp = Operation::create(state);
211+
rewriter.insert(newOp);
212+
return newOp->getResult(0);
213+
});
214+
return result.succeeded() ? WalkResult::advance()
215+
: WalkResult::interrupt();
216+
}
217+
218+
return WalkResult::advance();
219+
});
104220

105-
if (failed(mlir::applyPatternsGreedily(getOperation(), frozen)))
221+
if (result.wasInterrupted())
106222
return signalPassFailure();
107223
}

lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,22 @@ using namespace circt::synth;
3636

3737
/// Helper function to populate additional legal ops for partial legalization.
3838
template <typename... AllowedOpTy>
39-
static void partiallyLegalizeCombToSynth(SmallVectorImpl<std::string> &ops) {
39+
static void addOpName(SmallVectorImpl<std::string> &ops) {
4040
(ops.push_back(AllowedOpTy::getOperationName().str()), ...);
4141
}
42-
42+
template <typename... OpToLowerTy>
43+
static std::unique_ptr<Pass> createLowerVariadicPass(bool timingAware) {
44+
LowerVariadicOptions options;
45+
addOpName<OpToLowerTy...>(options.opNames);
46+
options.timingAware = timingAware;
47+
return createLowerVariadic(options);
48+
}
4349
void circt::synth::buildCombLoweringPipeline(
4450
OpPassManager &pm, const CombLoweringPipelineOptions &options) {
4551
{
4652
if (!options.disableDatapath) {
53+
// Lower variadic Mul into a binary op to enable datapath lowering.
54+
pm.addPass(createLowerVariadicPass<comb::MulOp>(options.timingAware));
4755
pm.addPass(createConvertCombToDatapath());
4856
pm.addPass(createSimpleCanonicalizerPass());
4957
if (options.synthesisStrategy == OptimizationStrategyTiming)
@@ -55,10 +63,9 @@ void circt::synth::buildCombLoweringPipeline(
5563
}
5664
// Partially legalize Comb, then run CSE and canonicalization.
5765
circt::ConvertCombToSynthOptions convOptions;
58-
partiallyLegalizeCombToSynth<comb::AndOp, comb::OrOp, comb::XorOp,
59-
comb::MuxOp, comb::ICmpOp, hw::ArrayGetOp,
60-
hw::ArraySliceOp, hw::ArrayCreateOp,
61-
hw::ArrayConcatOp, hw::AggregateConstantOp>(
66+
addOpName<comb::AndOp, comb::OrOp, comb::XorOp, comb::MuxOp, comb::ICmpOp,
67+
hw::ArrayGetOp, hw::ArraySliceOp, hw::ArrayCreateOp,
68+
hw::ArrayConcatOp, hw::AggregateConstantOp>(
6269
convOptions.additionalLegalOps);
6370
pm.addPass(circt::createConvertCombToSynth(convOptions));
6471
}
@@ -69,6 +76,18 @@ void circt::synth::buildCombLoweringPipeline(
6976
comb::BalanceMuxOptions balanceOptions{OptimizationStrategyTiming ? 16 : 64};
7077
pm.addPass(comb::createBalanceMux(balanceOptions));
7178

79+
// Lower variadic ops before running full lowering to target IR.
80+
if (options.targetIR.getValue() == TargetIR::AIG) {
81+
// For AIG, lower variadic Xor since AIG cannot keep variadic
82+
// representation.
83+
pm.addPass(createLowerVariadicPass<comb::XorOp>(options.timingAware));
84+
} else if (options.targetIR.getValue() == TargetIR::MIG) {
85+
// For MIG, lower variadic And, Or, and Xor since MIG cannot keep variadic
86+
// representation.
87+
pm.addPass(createLowerVariadicPass<comb::AndOp, comb::OrOp, comb::XorOp>(
88+
options.timingAware));
89+
}
90+
7291
pm.addPass(circt::hw::createHWAggregateToComb());
7392
circt::ConvertCombToSynthOptions convOptions;
7493
convOptions.targetIR = options.targetIR.getValue() == TargetIR::AIG
@@ -83,7 +102,7 @@ void circt::synth::buildCombLoweringPipeline(
83102
void circt::synth::buildSynthOptimizationPipeline(
84103
OpPassManager &pm, const SynthOptimizationPipelineOptions &options) {
85104

86-
pm.addPass(synth::createLowerVariadic());
105+
pm.addPass(createLowerVariadicPass(options.timingAware));
87106

88107
// LowerWordToBits may not be scalable for large designs so conditionally
89108
// disable it. It's also worth considering keeping word-level representation

0 commit comments

Comments
 (0)