[HW] Add HWVectorization pass #9222
Conversation
Force-pushed 59a27df to 7dfbad0
Hi everyone, just a gentle ping on this PR. It has been open for a while, and I wanted to check whether there is anything we can do on our side to help move the review forward. Many thanks!
```cpp
bit &operator=(const bit &other);
bool operator==(const bit &other) const;

bool left_adjacent(const bit &other);
```
nit: please use camelBack (as noted in MLIR style guide https://mlir.llvm.org/getting_started/DeveloperGuide/#style-guide)
```cpp
Block &block = module.getBody().front();
auto outputOp = dyn_cast<hw::OutputOp>(block.getTerminator());
if (!outputOp)
```
This if-statement is not necessary as hw::OutputOp is guaranteed by a verifier.
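As a hedged illustration of this point (a sketch only; `getOutputOp` is a hypothetical helper name, not from the patch):

```cpp
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Block.h"

// Since the hw.module verifier guarantees the body terminator is hw::OutputOp,
// an asserting cast<> can replace the dyn_cast + early return.
static circt::hw::OutputOp getOutputOp(mlir::Block &block) {
  return llvm::cast<circt::hw::OutputOp>(block.getTerminator());
}
```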
```cpp
bool containsLLHD = false;
module.walk([&](mlir::Operation *op) {
  if (op->getDialect()->getNamespace() == "llhd") {
```
Why does this give up when there is llhd?
```cpp
} else if (auto andOp = dyn_cast<comb::AndOp>(op)) {
  Value lhs = andOp.getInputs()[0];
  Value rhs = andOp.getInputs()[1];
  if (isa_and_nonnull<hw::ConstantOp>(rhs.getDefiningOp()))
    return findBitSource(lhs, bitIndex, depth + 1);
  if (isa_and_nonnull<hw::ConstantOp>(lhs.getDefiningOp()))
    return findBitSource(rhs, bitIndex, depth + 1);
```
I'm not sure I'm following the code correctly, but these parts don't seem correct. Is it necessary to check the value of the constant? Also, can't we simply treat and/or/xor as the source op here?
```cpp
bool vectorizer::cleanup_dead_ops(Block &block) {
  bool overallChanged = false;
  bool changedInIteration = true;
  while (changedInIteration) {
    changedInIteration = false;
    llvm::SmallVector<Operation *, 16> deadOps;
    for (Operation &op : block) {
      if (op.use_empty() && !op.hasTrait<mlir::OpTrait::IsTerminator>()) {
        deadOps.push_back(&op);
      }
    }
    if (!deadOps.empty()) {
      changedInIteration = true;
      overallChanged = true;
      for (Operation *op : deadOps) {
        op->erase();
      }
    }
  }
  return overallChanged;
}
```
Could you use https://github.com/llvm/llvm-project/blob/db557bee1e2c128e77805deb86c1f364b5c29e70/mlir/lib/Transforms/Utils/RegionUtils.cpp#L495? There are a few issues around side-effecting ops and the O(N^2) fixpoint iteration here, so it would be nice to simply use a library function.
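For illustration, a minimal sketch of that direction, assuming `mlir::isOpTriviallyDead` from mlir/IR/Operation.h (the linked RegionUtils helper would be the more complete option); the helper name here is hypothetical:

```cpp
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

// isOpTriviallyDead also rejects side-effecting ops and terminators. Walking
// the block in reverse lets a single sweep erase whole dead def-use chains,
// avoiding the fixpoint iteration above.
static bool eraseTriviallyDeadOps(mlir::Block &block) {
  bool changed = false;
  for (mlir::Operation &op :
       llvm::make_early_inc_range(llvm::reverse(block.getOperations()))) {
    if (mlir::isOpTriviallyDead(&op)) {
      op.erase();
      changed = true;
    }
  }
  return changed;
}
```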
```cpp
llvm::DenseSet<mlir::Value> sources;
for (const auto &[_, bit] : bits) {
  if (!sources.contains(bit.source))
    sources.insert(bit.source);
  if (sources.size() >= 2)
    return false;
}
return true;
```
super nit: using DenseSet is certainly overkill here, e.g.:
```cpp
mlir::Value source;
for (const auto &[_, bit] : bits) {
  if (source && source != bit.source)
    return false;
  source = bit.source;
}
return true;
```
Force-pushed a756a7e to ae4c6b6
Force-pushed ae4c6b6 to 1f3df90
Force-pushed dcbe132 to 80f2389
Hi @uenoku, Thank you very much for the review and for pointing out the issues with the previous approach. It was really helpful. I’ve reworked findBitSource to keep it strictly structural again, and moved all boolean reasoning into a separate helper (isBitConstant). This helper is intentionally limited: it only proves constants through structural traversal and identity propagation (e.g., and(x, 1) and or(x, 0)), and does not attempt general boolean simplification. The helper is used only to recognize identity masks in and/or, which allows handling the mux-like pattern in test_mux without turning findBitSource into a semantic evaluator. Please let me know if this direction looks more reasonable to you, or if you’d prefer an even more conservative restriction. Thanks again for the review!
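For readers following the thread, here is a rough, hypothetical sketch of the kind of identity-mask recognition described above (not the actual patch code; the helper name and shape are illustrative assumptions):

```cpp
#include <optional>

#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"

// Hypothetical sketch: if bit `bitIndex` of `v` comes from and(x, mask) with
// mask bit == 1, or from or(x, mask) with mask bit == 0, that bit equals x's
// bit, so a structural traversal can continue on x. No general boolean
// simplification is attempted.
static mlir::Value lookThroughIdentityMask(mlir::Value v, unsigned bitIndex) {
  auto getConstBit = [&](mlir::Value in) -> std::optional<bool> {
    auto cst = in.getDefiningOp<circt::hw::ConstantOp>();
    if (!cst || bitIndex >= cst.getValue().getBitWidth())
      return std::nullopt;
    return cst.getValue()[bitIndex];
  };
  if (auto andOp = v.getDefiningOp<circt::comb::AndOp>()) {
    if (andOp.getInputs().size() == 2) {
      mlir::Value lhs = andOp.getInputs()[0], rhs = andOp.getInputs()[1];
      if (auto b = getConstBit(rhs); b && *b) // and(x, 1) == x at this bit
        return lhs;
      if (auto b = getConstBit(lhs); b && *b)
        return rhs;
    }
  }
  if (auto orOp = v.getDefiningOp<circt::comb::OrOp>()) {
    if (orOp.getInputs().size() == 2) {
      mlir::Value lhs = orOp.getInputs()[0], rhs = orOp.getInputs()[1];
      if (auto b = getConstBit(rhs); b && !*b) // or(x, 0) == x at this bit
        return lhs;
      if (auto b = getConstBit(lhs); b && !*b)
        return rhs;
    }
  }
  return {}; // No identity mask recognized for this bit.
}
```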
Hi @uenoku, hope you’re doing well and had a great holiday season! I just wanted to gently follow up on this PR. It’s been rebased, all checks are passing, and it addresses the feedback about the isBitConstant helper. Happy to make any further changes if needed. Thanks a lot!
```cpp
if (!allBitsHaveSameSource() || bits.empty()) {
  return nullptr;
}
```
Suggested change:
```cpp
if (!allBitsHaveSameSource() || bits.empty())
  return nullptr;
```
```cpp
bool BitArray::allBitsHaveSameSource() const {
  mlir::Value source;
  for (const auto &[_, bit] : bits)
```
Note that the iteration order of DenseMap is non-deterministic; is there any place that depends on the order?
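If any rewrite does depend on the visitation order, one conventional fix (a sketch, not from the patch; the helper name is hypothetical) is to sort the bit indices first instead of walking the DenseMap directly:

```cpp
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Visit entries of an int-keyed map in ascending key order, independent of
// DenseMap's hashing; `Bit` stands in for the per-bit record type.
template <typename Bit, typename Fn>
void forEachBitInOrder(const llvm::DenseMap<int, Bit> &bits, Fn fn) {
  llvm::SmallVector<int> indices;
  for (const auto &entry : bits)
    indices.push_back(entry.first);
  llvm::sort(indices);
  for (int index : indices)
    fn(index, bits.find(index)->second);
}
```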
```cpp
  return true;
}

Bit BitArray::getBit(int n) { return bits[n]; }
```
Is it fine to mutate bits here? If n is not registered yet, it returns nullptr but is it handled in a caller?
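As a hedged alternative (reusing the names from the quoted code; not the actual fix in the patch), DenseMap::lookup returns a value-initialized Bit without inserting an entry, and callers that need to distinguish a missing index can use find() instead:

```cpp
// Non-mutating variant: lookup() does not create a default entry for a
// missing index, unlike operator[].
Bit BitArray::getBit(int n) const { return bits.lookup(n); }
```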
```cpp
IRRewriter rewriter(module.getContext());
bool changed = false;

for (Value oldOutputVal : outputOp->getOperands()) {
```
Does this only apply the transformation to output values? (It makes sense as a first step, but you might want to consider other operations like hw.instance/seq.compreg, etc.)
```cpp
}

bool Bit::operator==(const Bit &other) const {
  return source == other.source and index == other.index;
```
nit: for consistency with other places.

Suggested change:
```cpp
return source == other.source && index == other.index;
```
```cpp
  return false;
}

bool Vectorizer::canVectorizeStructurally(mlir::Value output) {
```
The current pattern separates analysis (can* functions) from transformation (apply* functions), which requires analyzing the IR twice and makes it very difficult to understand what's been validated before mutation occurs.
Please consider combining these into try* methods that return LogicalResult:
```cpp
LogicalResult tryStructuralVectorization(OpBuilder &builder, Value value);
LogicalResult tryPartialVectorization(OpBuilder &builder, Value value);
```
so that the call sites become:
```cpp
if (succeeded(tryLinearVectorization(oldOutputVal, sourceInput)))
  continue;
if (succeeded(tryReverseVectorization(rewriter, oldOutputVal, sourceInput)))
  continue;
if (succeeded(tryStructuralVectorization(rewriter, oldOutputVal)))
  continue;
if (succeeded(tryPartialVectorization(rewriter, oldOutputVal)))
  continue;
```

```cpp
while ((i - len) >= 0) {
  Value nextBitSource = findBitSource(oldOutputVal, i - len);
  auto nextExtractOp =
      dyn_cast_or_null<comb::ExtractOp>(nextBitSource.getDefiningOp());
```
Suggested change:
```cpp
nextBitSource.getDefiningOp<comb::ExtractOp>();
```
```cpp
cone.insert(val);

Operation *definingOp = val.getDefiningOp();
if (!definingOp || isa<BlockArgument>(val) ||
```
nit: !definingOp means isa<BlockArgument>(val)
```cpp
  }
}

bool Vectorizer::isSafeSharedValue(
```
Could you elaborate on why this is safe to share? Could you leave comments? I feel it always returns true.
```cpp
if (auto *op = val.getDefiningOp()) {
  for (auto operand : op->getOperands()) {
    if (!isSafeSharedValue(operand, visited))
```
Also, because visited is initialized every time at line 593, I think this function visits the entire use-def chain.
```cpp
struct HWVectorizationPass
    : public hw::impl::HWVectorizationBase<HWVectorizationPass> {

  void getDependentDialects(mlir::DialectRegistry &registry) const override {
```
please define this in tablegen. Also, the SV dialect is not necessary, I think.
```cpp
  return false;

if (auto c = dyn_cast<hw::ConstantOp>(defOp)) {
  if (bitIndex < c.getValue().getBitWidth())
```
Does this condition ever happen?
uenoku left a comment
I apologize for the long delay - I'm still working through understanding what the pass does, and it's taking me some time.
Since this PR is quite large and does non-trivial IR transformations, it might be much easier to review and merge quickly if you could split it into smaller PRs. Something like:
- Pass boilerplate and basic infrastructure
- BitArray data structure + ExtractOp preprocessing + linear vectorization (applyLinearVectorization)
- Reverse and mixed permutation vectorization + ConcatOp handling
- Logical op preprocessing (and/or/xor) + structural vectorization
- Partial vectorization
Keeping each PR under ~300 LOC would make it much easier to review, provide targeted feedback, and merge incrementally. It would also help me (and other reviewers) understand the design better by seeing it build up piece by piece.
What do you think?
This patch introduces the HWVectorization pass, which identifies bitwise patterns in hardware modules that can be represented as vectorized operations instead of per-bit logic.
The pass aims to simplify the IR by grouping related scalar bit operations (such as comb.extract and comb.concat) into higher-level vector constructs like comb.reverse, comb.replicate, or direct multi-bit comb.and, comb.or, and comb.xor.
The pass scans each hw.module and identifies groups of bit-level operations that can be merged into vector-level constructs. This version supports several key patterns based on bit-level dataflow analysis and structural analysis.
This patch was co-authored by @RosaUlisses.
Supported transformations include:
1. Linear concatenations (identity):
Pattern: Bits are extracted in ascending order (identity permutation) and concatenated.
Transformation: The entire comb.concat chain is replaced with the original input vector.

2. Bit reversal:
Pattern: Bits are extracted in descending (reverse) order and concatenated.
Transformation: The chain is replaced with a single comb.reverse. (A classification sketch for patterns 1 and 2 appears after this list.)

3. Structural Patterns (e.g., Vectorized Mux):
Pattern: Isomorphic, bit-parallel logic cones are detected. For example, a scalarized mux structure that uses a replicated i1 control signal for each bit.
Transformation: The replicated scalar operations are collapsed into equivalent vector-level operations (e.g., comb.replicate, comb.and, comb.xor, comb.or).

4. Partial Vectorization (Chunking):
Pattern: The pass identifies contiguous sub-ranges (chunks) that can be vectorized independently, even if the entire bus cannot be.
Transformation: The pass vectorizes the identifiable chunks (e.g., a linear chunk) and leaves the remaining scalar or structural logic as another chunk, then concatenates the chunks back together.
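To make patterns 1 and 2 concrete, below is a hypothetical sketch (not the patch's actual code; the enum and function names are illustrative assumptions) of how a comb.concat of single-bit extracts from one source could be classified as an identity or reversed permutation:

```cpp
#include "circt/Dialect/Comb/CombOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

enum class BitPermutation { Identity, Reverse, Other };

// comb.concat lists operands MSB-first; require each operand to be a
// single-bit comb.extract of the same source value, then compare the
// collected bit indices against the identity and reversed orders.
static BitPermutation classifyConcat(circt::comb::ConcatOp concat) {
  mlir::Value source;
  llvm::SmallVector<unsigned> bitIndices;
  for (mlir::Value operand : concat.getInputs()) {
    auto extract = operand.getDefiningOp<circt::comb::ExtractOp>();
    if (!extract ||
        llvm::cast<mlir::IntegerType>(operand.getType()).getWidth() != 1)
      return BitPermutation::Other;
    if (!source)
      source = extract.getInput();
    else if (source != extract.getInput())
      return BitPermutation::Other;
    bitIndices.push_back(extract.getLowBit());
  }
  unsigned n = bitIndices.size();
  if (n == 0 ||
      llvm::cast<mlir::IntegerType>(source.getType()).getWidth() != n)
    return BitPermutation::Other;
  bool identity = true, reverse = true;
  for (unsigned i = 0; i < n; ++i) {
    if (bitIndices[i] != n - 1 - i) // operand 0 carries bit n-1 for identity
      identity = false;
    if (bitIndices[i] != i) // operand 0 carries bit 0 for a bit reversal
      reverse = false;
  }
  if (identity)
    return BitPermutation::Identity; // concat can be replaced by `source`
  if (reverse)
    return BitPermutation::Reverse;
  return BitPermutation::Other;
}
```

An Identity result would let the concat be replaced directly with the source value, and a Reverse result corresponds to the single bit reversal described in pattern 2.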
Patterns not transformed
The pass does not modify modules with cross-bit dependencies or non-linear control flows.
For example: