Skip to content

Commit e589add

Browse files
authored
[Comb][circt-synth] Implement BalanceMux pass for optimizing mux chains (#9044)
This pass performs two main optimizations on mux chains: enhanced mux chain folding, which converts chains of muxes with index comparisons into a balanced mux tree, and priority encoder rebalancing, which transforms linear chains into balanced tree structures, reducing depth from O(n) to O(log n).

1. Priority encoder rebalancing:
```
# Before: Linear chain (O(n) depth)
result = cond0 ? val0 : (cond1 ? val1 : (cond2 ? val2 : default))

# After: Balanced tree (O(log n) depth)
left = cond0 ? val0 : val1
right = cond2 ? val2 : default
or_cond = cond0 | cond1
result = or_cond ? left : right
```

2. Mux chain folding with index comparisons:
```
# Before: Mux chain with comparisons
result = (index == 0) ? val0 : ((index == 1) ? val1 : ...)

# After: Balanced mux tree
result = index[n-1] ? (index[n-2] ? (index[n-3] ? ... : ...) : ...) : ...
```
1 parent ba41ee0 commit e589add

File tree

11 files changed

+596
-21
lines changed

11 files changed

+596
-21
lines changed

include/circt/Dialect/Comb/CombOps.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ Value createDynamicInject(OpBuilder &builder, Location loc, Value value,
8686
Value createInject(OpBuilder &builder, Location loc, Value value,
8787
unsigned offset, Value replacement);
8888

89+
/// Enum for mux chain folding styles.
90+
enum MuxChainWithComparisonFoldingStyle { None, BalancedMuxTree, ArrayGet };
91+
/// Mux chain folding that converts chains of muxes with index
92+
/// comparisons into array operations or balanced mux trees. `styleFn` is a
93+
/// callback that returns the desired folding style based on the index
94+
/// width and number of entries.
95+
bool foldMuxChainWithComparison(
96+
PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide,
97+
llvm::function_ref<MuxChainWithComparisonFoldingStyle(size_t indexWidth,
98+
size_t numEntries)>
99+
styleFn);
89100
} // namespace comb
90101
} // namespace circt
91102

include/circt/Dialect/Comb/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,16 @@ def CombIntRangeNarrowing : Pass<"comb-int-range-narrowing"> {
3434
}];
3535
}
3636

37+
def BalanceMux : Pass<"comb-balance-mux"> {
38+
let summary = "Balance and optimize mux chains";
39+
let description = [{
40+
Optimizes mux chains through enhanced folding and priority mux
41+
rebalancing. Converts index comparisons to arrays and rebalances
42+
linear chains into balanced trees, reducing depth from O(n) to O(log n).
43+
}];
44+
let options = [Option<
45+
"muxChainThreshold", "mux-chain-threshold", "unsigned", "16",
46+
"Minimum number of linear mux chains to trigger rebalancing">];
47+
}
48+
3749
#endif // CIRCT_DIALECT_COMB_PASSES_TD

include/circt/Dialect/Synth/Transforms/SynthPasses.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def PrintLongestPathAnalysis
120120
Option<"topModuleName", "top-module-name", "std::string", "",
121121
"Name of the top module to analyze (empty for automatic "
122122
"inference from instance graph)">];
123+
let dependentDialects = ["circt::comb::CombDialect", "circt::hw::HWDialect",
124+
"circt::synth::SynthDialect"];
123125
}
124126

125127
class ExternalSolverPass<string name> : Pass<name, "hw::HWModuleOp"> {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// RUN: circt-opt %s --pass-pipeline='builtin.module(synth-print-longest-path-analysis, hw.module(comb-balance-mux{mux-chain-threshold=4}))' -o %t.mlir | FileCheck %s --check-prefix=DEPTH_BEFORE
2+
// RUN: circt-opt %t.mlir --pass-pipeline='builtin.module(synth-print-longest-path-analysis)' | FileCheck %s --check-prefix=DEPTH_AFTER
3+
// RUN: circt-lec %t.mlir %s -c1=priority_mux_18_depth -c2=priority_mux_18_depth --shared-libs=%libz3 | FileCheck %s --check-prefix=MUX18_LEC
4+
// Check that balancing muxes reduces the longest path in a priority mux from O(n) to O(log n).
5+
// DEPTH_BEFORE-LABEL: priority_mux_18_depth
6+
// DEPTH_BEFORE: Maximum path delay: 17
7+
// DEPTH_AFTER-LABEL: priority_mux_18_depth
8+
// DEPTH_AFTER: Maximum path delay: 6
9+
// MUX18_LEC: c1 == c2
10+
hw.module @priority_mux_18_depth(in %cond0: i1, in %cond1: i1, in %cond2: i1, in %cond3: i1, in %cond4: i1, in %cond5: i1, in %cond6: i1, in %cond7: i1, in %cond8: i1, in %cond9: i1, in %cond10: i1, in %cond11: i1, in %cond12: i1, in %cond13: i1, in %cond14: i1, in %cond15: i1, in %cond16: i1, in %cond17: i1, out true_output: i5, out false_side: i5) {
11+
%c0_i5 = hw.constant 0 : i5
12+
%c1_i5 = hw.constant 1 : i5
13+
%c2_i5 = hw.constant 2 : i5
14+
%c3_i5 = hw.constant 3 : i5
15+
%c4_i5 = hw.constant 4 : i5
16+
%c5_i5 = hw.constant 5 : i5
17+
%c6_i5 = hw.constant 6 : i5
18+
%c7_i5 = hw.constant 7 : i5
19+
%c8_i5 = hw.constant 8 : i5
20+
%c9_i5 = hw.constant 9 : i5
21+
%c10_i5 = hw.constant 10 : i5
22+
%c11_i5 = hw.constant 11 : i5
23+
%c12_i5 = hw.constant 12 : i5
24+
%c13_i5 = hw.constant 13 : i5
25+
%c14_i5 = hw.constant 14 : i5
26+
%c15_i5 = hw.constant 15 : i5
27+
%c16_i5 = hw.constant 16 : i5
28+
%c17_i5 = hw.constant 17 : i5
29+
30+
%mux17_t = comb.mux %cond16, %c17_i5, %c0_i5 : i5
31+
%mux16_t = comb.mux %cond15, %mux17_t, %c16_i5 : i5
32+
%mux15_t = comb.mux %cond14, %mux16_t, %c15_i5 : i5
33+
%mux14_t = comb.mux %cond13, %mux15_t, %c14_i5 : i5
34+
%mux13_t = comb.mux %cond12, %mux14_t, %c13_i5 : i5
35+
%mux12_t = comb.mux %cond11, %mux13_t, %c12_i5 : i5
36+
%mux11_t = comb.mux %cond10, %mux12_t, %c11_i5 : i5
37+
%mux10_t = comb.mux %cond9, %mux11_t, %c10_i5 : i5
38+
%mux9_t = comb.mux %cond8, %mux10_t, %c9_i5 : i5
39+
%mux8_t = comb.mux %cond7, %mux9_t, %c8_i5 : i5
40+
%mux7_t = comb.mux %cond6, %mux8_t, %c7_i5 : i5
41+
%mux6_t = comb.mux %cond5, %mux7_t, %c6_i5 : i5
42+
%mux5_t = comb.mux %cond4, %mux6_t, %c5_i5 : i5
43+
%mux4_t = comb.mux %cond3, %mux5_t, %c4_i5 : i5
44+
%mux3_t = comb.mux %cond2, %mux4_t, %c3_i5 : i5
45+
%mux2_t = comb.mux %cond1, %mux3_t, %c2_i5 : i5
46+
%mux1_t = comb.mux %cond0, %mux2_t, %c1_i5 : i5
47+
48+
%mux17_f = comb.mux %cond16, %c0_i5, %c17_i5 : i5
49+
%mux16_f = comb.mux %cond15, %c16_i5, %mux17_f : i5
50+
%mux15_f = comb.mux %cond14, %c15_i5, %mux16_f : i5
51+
%mux14_f = comb.mux %cond13, %c14_i5, %mux15_f : i5
52+
%mux13_f = comb.mux %cond12, %c13_i5, %mux14_f : i5
53+
%mux12_f = comb.mux %cond11, %c12_i5, %mux13_f : i5
54+
%mux11_f = comb.mux %cond10, %c11_i5, %mux12_f : i5
55+
%mux10_f = comb.mux %cond9, %c10_i5, %mux11_f : i5
56+
%mux9_f = comb.mux %cond8, %c9_i5, %mux10_f : i5
57+
%mux8_f = comb.mux %cond7, %c8_i5, %mux9_f : i5
58+
%mux7_f = comb.mux %cond6, %c7_i5, %mux8_f : i5
59+
%mux6_f = comb.mux %cond5, %c6_i5, %mux7_f : i5
60+
%mux5_f = comb.mux %cond4, %c5_i5, %mux6_f : i5
61+
%mux4_f = comb.mux %cond3, %c4_i5, %mux5_f : i5
62+
%mux3_f = comb.mux %cond2, %c3_i5, %mux4_f : i5
63+
%mux2_f = comb.mux %cond1, %c2_i5, %mux3_f : i5
64+
%mux1_f = comb.mux %cond0, %c1_i5, %mux2_f : i5
65+
66+
hw.output %mux1_t, %mux1_f : i5, i5
67+
}
68+
69+
// RUN: circt-lec %t.mlir %s -c1=index_to_balanced_mux -c2=index_to_balanced_mux --shared-libs=%libz3 | FileCheck %s --check-prefix=INDEX_TO_BALANCED_MUX_LEC
70+
// DEPTH_BEFORE-LABEL: Longest Path Analysis result for "index_to_balanced_mux"
71+
// DEPTH_BEFORE: Maximum path delay: 11
72+
// DEPTH_AFTER-LABEL: Longest Path Analysis result for "index_to_balanced_mux"
73+
// DEPTH_AFTER: Maximum path delay: 3
74+
// INDEX_TO_BALANCED_MUX_LEC: c1 == c2
75+
hw.module @index_to_balanced_mux(in %index: i3, out result: i8) {
76+
// Values to select from based on index
77+
%a = hw.constant 10 : i8
78+
%b = hw.constant 20 : i8
79+
%c = hw.constant 30 : i8
80+
%d = hw.constant 40 : i8
81+
%e = hw.constant 50 : i8
82+
%f = hw.constant 60 : i8
83+
%g = hw.constant 70 : i8
84+
%default = hw.constant 0 : i8
85+
86+
// Index comparison constants
87+
%c0 = hw.constant 0 : i3
88+
%c1 = hw.constant 1 : i3
89+
%c2 = hw.constant 2 : i3
90+
%c3 = hw.constant 3 : i3
91+
%c4 = hw.constant 4 : i3
92+
%c5 = hw.constant 5 : i3
93+
%c6 = hw.constant 6 : i3
94+
95+
%eq0 = comb.icmp eq %index, %c0 : i3
96+
%eq1 = comb.icmp eq %index, %c1 : i3
97+
%eq2 = comb.icmp eq %index, %c2 : i3
98+
%eq3 = comb.icmp eq %index, %c3 : i3
99+
%eq4 = comb.icmp eq %index, %c4 : i3
100+
%eq5 = comb.icmp eq %index, %c5 : i3
101+
%eq6 = comb.icmp eq %index, %c6 : i3
102+
103+
%mux6 = comb.mux %eq6, %g, %default : i8
104+
%mux5 = comb.mux %eq5, %f, %mux6 : i8
105+
%mux4 = comb.mux %eq4, %e, %mux5 : i8
106+
%mux3 = comb.mux %eq3, %d, %mux4 : i8
107+
%mux2 = comb.mux %eq2, %c, %mux3 : i8
108+
%mux1 = comb.mux %eq1, %b, %mux2 : i8
109+
%result = comb.mux %eq0, %a, %mux1 : i8
110+
111+
hw.output %result : i8
112+
}
113+

lib/Dialect/Comb/CombFolds.cpp

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "circt/Dialect/HW/HWAttributes.h"
1111
#include "circt/Dialect/HW/HWOps.h"
1212
#include "circt/Support/Naming.h"
13+
#include "mlir/IR/Diagnostics.h"
1314
#include "mlir/IR/Matchers.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "llvm/ADT/SetVector.h"
@@ -1973,13 +1974,17 @@ getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
19731974
/// Given a mux, check to see if the "on true" value (or "on false" value if
19741975
/// isFalseSide=true) is a mux tree with the same condition. This allows us
19751976
/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
1976-
/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
1977-
/// allows synthesis tools to do more interesting optimizations.
1977+
/// `array_get (array_create(A, B, C), VAL)` or a balanced mux tree which is far
1978+
/// more compact and allows synthesis tools to do more interesting
1979+
/// optimizations.
19781980
///
19791981
/// This returns false if we cannot form the mux tree (or do not want to) and
19801982
/// returns true if the mux was replaced.
1981-
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
1982-
PatternRewriter &rewriter) {
1983+
bool comb::foldMuxChainWithComparison(
1984+
PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide,
1985+
llvm::function_ref<MuxChainWithComparisonFoldingStyle(size_t indexWidth,
1986+
size_t numEntries)>
1987+
styleFn) {
19831988
// Get the index value being compared. Later we check to see if it is
19841989
// compared to a constant with the right predicate.
19851990
auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
@@ -2039,24 +2044,16 @@ static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
20392044
nextTreeValue = getTreeValue(nextMux);
20402045
}
20412046

2042-
// We need to have more than three values to create an array. This is an
2043-
// arbitrary threshold which is saying that one or two muxes together is ok,
2044-
// but three should be folded.
2045-
if (valuesFound.size() < 3)
2046-
return false;
2047-
2048-
// If the array is greater that 9 bits, it will take over 512 elements and
2049-
// it will be too large for a single expression.
20502047
auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2051-
if (indexWidth >= 9)
2048+
2049+
if (indexWidth > 20)
2050+
return false; // Too big to make a table.
2051+
2052+
auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2053+
if (foldingStyle == MuxChainWithComparisonFoldingStyle::None)
20522054
return false;
20532055

2054-
// Next we need to see if the values are dense-ish. We don't want to have
2055-
// a tremendous number of replicated entries in the array. Some sparsity is
2056-
// ok though, so we require the table to be at least 5/8 utilized.
20572056
uint64_t tableSize = 1ULL << indexWidth;
2058-
if (valuesFound.size() < (tableSize * 5) / 8)
2059-
return false; // Not dense enough.
20602057

20612058
// Ok, we're going to do the transformation, start by building the table
20622059
// filled with the "otherwise" value.
@@ -2071,13 +2068,25 @@ static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
20712068
table[idx] = elt.second;
20722069
}
20732070

2071+
if (foldingStyle == MuxChainWithComparisonFoldingStyle::BalancedMuxTree) {
2072+
SmallVector<Value> bits;
2073+
comb::extractBits(rewriter, indexValue, bits);
2074+
auto result = constructMuxTree(rewriter, rootMux->getLoc(), bits, table,
2075+
nextTreeValue);
2076+
replaceOpAndCopyNamehint(rewriter, rootMux, result);
2077+
return true;
2078+
}
2079+
2080+
assert(foldingStyle == MuxChainWithComparisonFoldingStyle::ArrayGet &&
2081+
"unknown folding style");
2082+
20742083
// The hw.array_create operation has the operand list in unintuitive order
20752084
// with a[0] stored as the last element, not the first.
20762085
std::reverse(table.begin(), table.end());
20772086

20782087
// Build the array_create and the array_get.
20792088
auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2080-
auto array = hw::ArrayCreateOp::create(rewriter, fusedLoc, table);
2089+
auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
20812090
replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
20822091
indexValue);
20832092
return true;
@@ -2376,6 +2385,22 @@ struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
23762385
PatternRewriter &rewriter) const override;
23772386
};
23782387

2388+
MuxChainWithComparisonFoldingStyle
2389+
foldToArrayCreateOnlyWhenDense(size_t indexWidth, size_t numEntries) {
2390+
// If the index is 9 bits or wider, the array would need 512 or more
2391+
// elements, which is too large for a single expression.
2392+
if (indexWidth >= 9 || numEntries < 3)
2393+
return MuxChainWithComparisonFoldingStyle::None;
2394+
2395+
// Next we need to see if the values are dense-ish. We don't want to have
2396+
// a tremendous number of replicated entries in the array. Some sparsity is
2397+
// ok though, so we require the table to be at least 5/8 utilized.
2398+
uint64_t tableSize = 1ULL << indexWidth;
2399+
if (numEntries >= tableSize * 5 / 8)
2400+
return MuxChainWithComparisonFoldingStyle::ArrayGet;
2401+
return MuxChainWithComparisonFoldingStyle::None;
2402+
}
2403+
23792404
LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
23802405
PatternRewriter &rewriter) const {
23812406
if (isOpTriviallyRecursive(op))
@@ -2530,7 +2555,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
25302555
}
25312556

25322557
// Check to see if we can fold a mux tree into an array_create/get pair.
2533-
if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2558+
if (foldMuxChainWithComparison(rewriter, op, /*isFalse*/ true,
2559+
foldToArrayCreateOnlyWhenDense))
25342560
return success();
25352561
}
25362562

@@ -2545,7 +2571,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
25452571
}
25462572

25472573
// Check to see if we can fold a mux tree into an array_create/get pair.
2548-
if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2574+
if (foldMuxChainWithComparison(rewriter, op, /*isFalseSide*/ false,
2575+
foldToArrayCreateOnlyWhenDense))
25492576
return success();
25502577
}
25512578

0 commit comments

Comments
 (0)