[water] propagate vector shapes along index expressions#1160
Conversation
Several places in the index expression analysis have been directly overriding pre-existing lattice values, in particular for:

- block arguments during forward initialization, with bottom;
- terminator operands during backward initialization, with bottom;
- op operands during backward initialization, with the value returned by the interface call.

This may, and effectively does, move back on the lattice, which may cause problems with analysis convergence or consistency. Furthermore, since the lattices are shared between forward and backward analyses, initialization of one of them may override the previously set lattice value. The two former cases were attempts to defend against the awkward upstream setup where lattices are initialized to the top value for block arguments during per-op initialization. This isn't needed since we have a mechanism inside `setToEntry/ExitState` to avoid that and initialize to bottom during per-op initialization. Drop that. The latter case is due to interface implementations overriding the pre-existing values instead of joining them. While this may be justified in some cases, we need to ensure the lattice doesn't go back from values already available. More generally, initialization triggers some propagation, and so does lattice sharing between analyses. Therefore, initialization must also join with the pre-existing values and potentially report errors. Guard against lattice direction changes in debug mode. Make sure error messages are emitted if joining lattices in initialization results in the top state. Some of these are difficult to trigger directly without complex test scaffolding, so we make an intentional decision to only test a subset of messages.
In particular, initialization for the backward pass will run after (1) default-initialization for forward that visits all operations; (2) initialization for forward; (3) default-initialization for backward that visits all operations again; and triggering an error message means none of the above should have detected the conflict first. A potential solution to that is to meld initialization and visitation functions, but it comes with higher code complexity of visitation functions and potentially higher cost. Signed-off-by: Alex Zinenko <git@ozinenko.com>
MMA ops require specific index expressions set during initialization of the analysis. We shouldn't allow these values to be overwritten, even with indexes defined by other MMA ops, so that indexing remains correct and the right instruction is selected. If the index for an MMA operand specified at its definition point differs from what MMA expects, a yet-to-be-ported pywave pass will introduce a reshape operation. A more holistic solution would inject "virtual reshape" operations and update the propagation to work with sets of all possible index expressions rather than immediately detecting conflicts. A separate analysis would then choose between the possible expressions based on some criteria, materialize meaningful reshapes and drop the noop ones. Signed-off-by: Alex Zinenko <git@ozinenko.com> Assisted-by: Claude Opus 4.6 <noreply@antropic.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
These two quantities are related: vector shapes indicate the (desired) shape of the data processed by each operation instance after distribution and unrolling, and index expressions indicate the position of that data in the larger tensor. In some cases, they may influence each other during propagation. In the future, we should consider folding them into one data structure, but for now be consistent with pywave. Signed-off-by: Alex Zinenko <git@ozinenko.com>
Pull request overview
This PR extends Wave index-expression dataflow to also carry and join per-dimension vector shapes alongside index expressions, and adjusts MMA behavior to keep MMA-specific index expressions stable (for pywave compatibility).
Changes:
- Extend `IndexExprsLatticeStorage` to store `vectorShape` and propagate/join it together with index expressions (including symbol filtering).
- Stop forward/backward propagation through `wave.mma` (index exprs stay as initialized by MMA kind), and special-case index attribute materialization for MMA ops.
- Update test override initialization parsing to support optional vector-shape overrides, and expand/adjust MLIR tests accordingly.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | Implements vectorShape storage, filtering, and join logic in the index-expr lattice. |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.h | Extends lattice API/state to include vectorShape and related helpers. |
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Disables MMA index-expr propagation; initializes MMA/Write lattices while joining in vectorShape; preserves vectorShape across permutes. |
| water/lib/Dialect/Wave/Transforms/InferTypes.cpp | Adds safeSet direction checks; threads vectorShape through override initialization; special-cases MMA index attribute emission. |
| water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h | Updates override-initialization callback signature to include vectorShape. |
| water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp | Updates test override parsing to support multiple formats including optional vectorShape. |
| water/test/Dialect/Wave/infer-index-exprs.mlir | Renames/extends MMA-chain coverage and adds a test ensuring no propagation through chained MMAs. |
| water/test/Dialect/Wave/infer-index-exprs-lattice.mlir | Reworks lattice tests and adds new vector-shape join/conflict coverage. |
| // Override the value-to-store operand with the index expression implied by | ||
| // write to avoid sideways propagation. This will hit an error during | ||
| // initializaiton rather than during propagation. |
Typo in comment: “initializaiton” should be “initialization”.
| // initializaiton rather than during propagation. | |
| // initialization rather than during propagation. |
| if (lhs) { | ||
| Attribute lhsValue = lhs.get(attr.getName()); | ||
| if (lhsValue && lhsValue != attr.getValue()) | ||
| return nullptr; | ||
| continue; | ||
| } |
`getJoinedVectorShape` never appends RHS-only entries because the `if (lhs)` branch is always taken once both dictionaries are non-null; this causes vector-shape information from `rhs` to be dropped even when there is no conflict. Fix by checking whether `lhs` contains the current key and only skipping/validating in that case; otherwise append the RHS attribute to the joined dictionary.
| if (lhs) { | |
| Attribute lhsValue = lhs.get(attr.getName()); | |
| if (lhsValue && lhsValue != attr.getValue()) | |
| return nullptr; | |
| continue; | |
| } | |
| // If LHS already has this symbol, ensure the values are equal. | |
| Attribute lhsValue = lhs.get(attr.getName()); | |
| if (lhsValue) { | |
| if (lhsValue != attr.getValue()) | |
| return nullptr; | |
| // Equal value already present from LHS, no need to add RHS duplicate. | |
| continue; | |
| } | |
| // Symbol is only present on RHS (and not ignored), so append it. |
| // mapping symbol names to vector dimension sizes. Two concrete lattice values | ||
| // with different vector shapes cannot be joined and will result in top. |
The comment says "Two concrete lattice values with different vector shapes cannot be joined and will result in top", but the implementation of vector-shape joining in `IndexExprsLatticeStorage::join` factors in lattice priority and can select one side instead of producing top. Please align this comment with the actual semantics (or adjust the join semantics if priority is not intended to apply to vector shapes).
| // mapping symbol names to vector dimension sizes. Two concrete lattice values | |
| // with different vector shapes cannot be joined and will result in top. | |
| // mapping symbol names to vector dimension sizes. In general, two concrete | |
| // lattice values with different vector shapes are considered incompatible and | |
| // their join will yield top, unless the lattice priority causes one side to | |
| // override the other during the join. |
| assert(constraints && "constraints not found"); | ||
| if (failed(mma.initializeIndexExprsBackward( | ||
| operandLattices, | ||
| *wave::IndexExprsAnalysisInit::create( | ||
| mma->getLoc(), constraints, | ||
| wave::getHyperparameters(mma)), |
In the MMA special-case path, `constraints` is required but is only enforced via `assert(constraints && "constraints not found")`, and then `*IndexExprsAnalysisInit::create(...)` is dereferenced unconditionally. In release builds, this can crash if an MMA is encountered outside a parent op with `wave.constraints`. Please replace the assert with a proper check that emits an error and interrupts the walk when constraints are missing (and avoid dereferencing the `create(...)` result without checking it succeeded).
| assert(constraints && "constraints not found"); | |
| if (failed(mma.initializeIndexExprsBackward( | |
| operandLattices, | |
| *wave::IndexExprsAnalysisInit::create( | |
| mma->getLoc(), constraints, | |
| wave::getHyperparameters(mma)), | |
| if (!constraints) { | |
| mma->emitError() | |
| << "missing '" << wave::WaveDialect::kWaveConstraintsAttrName | |
| << "' attribute on a parent operation required for MMA index " | |
| "expression inference"; | |
| return WalkResult::interrupt(); | |
| } | |
| auto analysisInit = wave::IndexExprsAnalysisInit::create( | |
| mma->getLoc(), constraints, wave::getHyperparameters(mma)); | |
| if (!analysisInit) { | |
| mma->emitError() | |
| << "failed to initialize index expression analysis for MMA " | |
| "operation"; | |
| return WalkResult::interrupt(); | |
| } | |
| if (failed(mma.initializeIndexExprsBackward( | |
| operandLattices, *analysisInit, |
| << "expected " << attributeName | ||
| << " to be an array of " | ||
| "dictionaries with WaveIndexMappingAttr or UnitAttr values"; |
The override parsing error message is now misleading/inconsistent with the accepted formats: this code accepts a bare `DictionaryAttr` (not just an array), and it currently rejects `UnitAttr` values inside the `indexExprs` dictionary (the lambda only permits `WaveIndexMappingAttr`). Please update the diagnostic text to reflect the actual accepted formats/values, or relax the validation if `UnitAttr` values are intended to be allowed in the dictionary.
| << "expected " << attributeName | |
| << " to be an array of " | |
| "dictionaries with WaveIndexMappingAttr or UnitAttr values"; | |
| << "expected each element of " << attributeName | |
| << " to be either a DictionaryAttr or an array " | |
| "[priority?, indexExprs, vectorShape?], where indexExprs is " | |
| "a DictionaryAttr whose values are WaveIndexMappingAttr"; |
| if (joined.isTop()) { | ||
| InFlightDiagnostic diag = emitError() | ||
| << "conflict for " << latticeName | ||
| << " index expression when propagating from " | ||
| << otherName << " lattice"; | ||
| diag.attachNote() << "original " << latticeName << " lattice: " << lattice; | ||
| diag.attachNote() << otherName << " lattice: " << other; | ||
| return diag; |
`joinIndexExprsLatticeInPlace` reports any join-to-top as an "index expression" conflict, but `IndexExprsLatticeStorage::join` can also return top due to a vector-shape conflict. This can make diagnostics confusing. Consider detecting vector-shape conflicts (e.g., via `IndexExprsLatticeStorage::hasVectorShapeConflict`) and tailoring the error/notes to mention the vector-shape mismatch (and include the `vectorShape` values).