Skip to content

[water] propagate vector shapes along index expressions#1160

Open
ftynse wants to merge 4 commits intomainfrom
users/ftynse/vector-shape
Open

[water] propagate vector shapes along index expressions#1160
ftynse wants to merge 4 commits intomainfrom
users/ftynse/vector-shape

Conversation

@ftynse
Copy link
Contributor

@ftynse ftynse commented Mar 20, 2026

These two quantities are related: vector shapes indicate the (desired)
shape of the data processed by each operation instance after
distribution and unrolling, index expressions the position of the that
in the larger tensor. In some cases, they may influence each other
during propagation. In future, we should consider folding them into one
data structure but for now be consistent with pywave.

ftynse added 4 commits March 20, 2026 19:00
Several places in the index expression analysis have been directly
overriding pre-existing lattice values, in particular for:

 - block arguments during forward initialization, with bottom;
 - terminator operands during backward initialization, with bottom;
 - op operands during backward initialization, with the value returned
   by the interface call.

This may and effectively does move back on the lattice, which may cause
problems with analysis convergence or consistency. Furthermore, since
the lattice are shared between forward and backward analyses,
initialization of one of them may override the previously set lattice
value.

The two former cases were attempts to defend against the awkward
upstream setup where lattices are initialized to the top value for
block arguments during per-op initialization. This isn't needed since we
have a mechanism inside `setToEntry/ExitState` to avoid that and
initialize to bottom during per-op initizliation. Drop that.

The latter case is due to interface implementations overriding the
pre-existing values instead of joining them. While this may be justified
in some cases, we need to ensure the lattice doesn't go back from values
already available. More generally, initialization triggers some
propagation and so does lattice sharing between analyses. Therefore,
initialization must also join with the pre-existing values and
potentially report errors. Guard against lattice direction changes in
debug mode.

Make sure error messages are emitted if joining lattices in
initialization results in the top state. Some of these are difficult to
trigger directly without complex test scaffolding, so we make an
intentional decision to only test a subset of messages. In particular,
initalization for backward pass will run after (1)
default-initialization for forward that visits all operations; (2)
initializaiton for forward; (3) default-initialization for backward that
visits all operations again; and triggerring an error message means none
of the above should have detected the conflict first. A potential
solution to that is to meld initalization and visitation functions, but
it comes with higher code complexity of visitation functions and
potentially higher cost.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
MMA ops require specific index expressions set during initialization of
the analysis. We shouldn't allow these values to be overwritten, even
with indexes defined by other MMA ops to avoid indexing and hit the
right instruction. If the index for an MMA operand specified at its
definition point differs from what MMA expects, a yet-to-be-ported
pywave pass will introduce a reshape operation.

A more holistic solution would inject "virtual reshape" operations and
update the propagation to work with sets of all possible index
expressions rather than immediately detecting conflicts. A separate
analysis would then choose between the possible expressions based on
some criteria, materialize meaningful reshapes and drop the noop ones.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Assited-by: Claude Opus 4.6 <noreply@antropic.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
These two quantities are related: vector shapes indicate the (desired)
shape of the data processed by each operation instance after
distribution and unrolling, index expressions the position of the that
in the larger tensor. In some cases, they may influence each other
during propagation. In future, we should consider folding them into one
data structure but for now be consistent with pywave.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse requested a review from tgymnich March 20, 2026 18:27
@github-actions
Copy link

Water Code Coverage

Filename                                                           Functions  Missed Functions  Executed       Lines      Missed Lines     Cover    Branches   Missed Branches     Cover
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
lib/Transforms/MemrefDecomposition.cpp                                    28                 0   100.00%         600                49    91.83%         104                46    55.77%
lib/Transforms/AllocToAlloca.cpp                                           2                 0   100.00%          17                 0   100.00%           0                 0         -
lib/Transforms/CheckStaticAssertions.cpp                                   2                 0   100.00%          22                 1    95.45%           8                 4    50.00%
lib/Transforms/GPUModuleToBinary.cpp                                      19                 5    73.68%         339               115    66.08%         128                57    55.47%
lib/Transforms/DropTransformOps.cpp                                        2                 0   100.00%          16                 0   100.00%           2                 0   100.00%
lib/Transforms/GPUToGPURuntime.cpp                                        14                 0   100.00%         298                23    92.28%          40                17    57.50%
lib/Transforms/SLPVectorizer.cpp                                          61                 3    95.08%        1065                99    90.70%         558               167    70.07%
lib/Transforms/AccessCheckers.cpp                                         35                 1    97.14%         446                40    91.03%         124                30    75.81%
lib/Transforms/AssembleISA.cpp                                             4                 1    75.00%          30                 2    93.33%           2                 1    50.00%
lib/Dialect/Wave/Transforms/LoweringPatterns.cpp                          45                 2    95.56%         913               141    84.56%         262                78    70.23%
lib/Dialect/Wave/Transforms/PropagateDefaultsFromConstraints.cpp           3                 3     0.00%          35                35     0.00%          12                12     0.00%
lib/Dialect/Wave/Transforms/TypeConverter.cpp                              7                 2    71.43%          96                26    72.92%          32                17    46.88%
lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp                         10                 0   100.00%         238                18    92.44%          58                11    81.03%
lib/Dialect/Wave/Transforms/DetectNormalForms.cpp                          4                 0   100.00%          48                 0   100.00%           8                 0   100.00%
lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp                   2                 0   100.00%          23                 1    95.65%           6                 1    83.33%
lib/Dialect/Wave/Transforms/InferTypes.cpp                               106                14    86.79%        1750               144    91.77%         848               435    48.70%
lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp                            5                 0   100.00%         130                 1    99.23%          16                 2    87.50%
lib/Dialect/Wave/Transforms/Utils.cpp                                      6                 0   100.00%          96                 5    94.79%          26                 4    84.62%
lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp              7                 0   100.00%         183                16    91.26%          32                14    56.25%
lib/Dialect/Wave/IR/WaveOps.cpp                                          129                15    88.37%        2256               258    88.56%         936               193    79.38%
lib/Dialect/Wave/IR/WaveAttrs.cpp                                         72                 7    90.28%         921                95    89.69%         408                60    85.29%
lib/Dialect/Wave/IR/IndexExpr.cpp                                         11                 0   100.00%         119                 1    99.16%          24                 3    87.50%
lib/Dialect/Wave/IR/WaveDialect.cpp                                       13                 0   100.00%         440                11    97.50%         150                 7    95.33%
lib/Dialect/Wave/IR/WaveTypes.cpp                                          9                 1    88.89%          75                 8    89.33%          18                 3    83.33%
lib/Dialect/Wave/IR/WaveInterfaces.cpp                                    88                 6    93.18%        1265                87    93.12%         548                84    84.67%
lib/Dialect/Wave/IR/WaveUtils.cpp                                          7                 0   100.00%         129                 7    94.57%          52                10    80.77%
lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp                3                 0   100.00%          34                 6    82.35%           8                 2    75.00%
lib/Dialect/NormalForm/IR/NormalFormDialect.cpp                            1                 0   100.00%           6                 0   100.00%           0                 0         -
lib/Dialect/NormalForm/IR/NormalFormOps.cpp                               12                 0   100.00%         201                 9    95.52%          58                 7    87.93%
lib/Pipelines/Pipelines.cpp                                                2                 0   100.00%          27                 0   100.00%           0                 0         -
lib/Analysis/InUseForSpeculation.cpp                                      12                 1    91.67%         142                 8    94.37%          32                 4    87.50%
include/water/Dialect/Wave/Transforms/LoweringPatterns.h                   1                 0   100.00%           3                 0   100.00%           0                 0         -
include/water/Dialect/Wave/IR/IndexExpr.h                                  1                 0   100.00%          10                 0   100.00%           2                 0   100.00%
include/water/Dialect/Wave/IR/WaveInterfaces.h                            38                 3    92.11%         148                 8    94.59%           8                 2    75.00%
include/water/Dialect/Wave/IR/WaveTypes.h                                  1                 0   100.00%           5                 0   100.00%           4                 0   100.00%
include/water/Dialect/Wave/IR/WaveUtils.h                                  1                 0   100.00%           5                 1    80.00%           4                 2    50.00%
include/water/Dialect/Wave/IR/WaveAttrs.h                                  4                 0   100.00%          14                 0   100.00%           0                 0         -
include/water/Dialect/NormalForm/IR/NormalFormInterfaces.h                 1                 1     0.00%           4                 4     0.00%           0                 0         -
include/water/Analysis/InUseForSpeculation.h                              12                 3    75.00%          39                17    56.41%          16                10    37.50%
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TOTAL                                                                    780                68    91.28%       12188              1236    89.86%        4534              1283    71.70%

Download full HTML report

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends Wave index-expression dataflow to also carry and join per-dimension vector shapes alongside index expressions, and adjusts MMA behavior to keep MMA-specific index expressions stable (for pywave compatibility).

Changes:

  • Extend IndexExprsLatticeStorage to store vectorShape and propagate/join it together with index expressions (including symbol filtering).
  • Stop forward/backward propagation through wave.mma (index exprs stay as initialized by MMA kind), and special-case index attribute materialization for MMA ops.
  • Update test override initialization parsing to support optional vector-shape overrides, and expand/adjust MLIR tests accordingly.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
water/lib/Dialect/Wave/IR/WaveInterfaces.cpp Implements vectorShape storage, filtering, and join logic in the index-expr lattice.
water/include/water/Dialect/Wave/IR/WaveInterfaces.h Extends lattice API/state to include vectorShape and related helpers.
water/lib/Dialect/Wave/IR/WaveOps.cpp Disables MMA index-expr propagation; initializes MMA/Write lattices while joining in vectorShape; preserves vectorShape across permutes.
water/lib/Dialect/Wave/Transforms/InferTypes.cpp Adds safeSet direction checks; threads vectorShape through override initialization; special-cases MMA index attribute emission.
water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h Updates override-initialization callback signature to include vectorShape.
water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp Updates test override parsing to support multiple formats including optional vectorShape.
water/test/Dialect/Wave/infer-index-exprs.mlir Renames/extends MMA-chain coverage and adds a test ensuring no propagation through chained MMAs.
water/test/Dialect/Wave/infer-index-exprs-lattice.mlir Reworks lattice tests and adds new vector-shape join/conflict coverage.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


// Override the value-to-store operand with the index expression implied by
// write to avoid sideways propagation. This will hit an error during
// initializaiton rather than during propagation.
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: “initializaiton” should be “initialization”.

Suggested change
// initializaiton rather than during propagation.
// initialization rather than during propagation.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Comment on lines +1287 to +1292
if (lhs) {
Attribute lhsValue = lhs.get(attr.getName());
if (lhsValue && lhsValue != attr.getValue())
return nullptr;
continue;
}
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getJoinedVectorShape never appends RHS-only entries because the if (lhs) branch is always taken once both dictionaries are non-null; this causes vector-shape information from rhs to be dropped even when there is no conflict. Fix by checking whether lhs contains the current key and only skipping/validating in that case; otherwise append the RHS attribute to the joined dictionary.

Suggested change
if (lhs) {
Attribute lhsValue = lhs.get(attr.getName());
if (lhsValue && lhsValue != attr.getValue())
return nullptr;
continue;
}
// If LHS already has this symbol, ensure the values are equal.
Attribute lhsValue = lhs.get(attr.getName());
if (lhsValue) {
if (lhsValue != attr.getValue())
return nullptr;
// Equal value already present from LHS, no need to add RHS duplicate.
continue;
}
// Symbol is only present on RHS (and not ignored), so append it.

Copilot uses AI. Check for mistakes.
Comment on lines +668 to +669
// mapping symbol names to vector dimension sizes. Two concrete lattice values
// with different vector shapes cannot be joined and will result in top.
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says “Two concrete lattice values with different vector shapes cannot be joined and will result in top”, but the implementation of vector-shape joining in IndexExprsLatticeStorage::join factors in lattice priority and can select one side instead of producing top. Please align this comment with the actual semantics (or adjust the join semantics if priority is not intended to apply to vector shapes).

Suggested change
// mapping symbol names to vector dimension sizes. Two concrete lattice values
// with different vector shapes cannot be joined and will result in top.
// mapping symbol names to vector dimension sizes. In general, two concrete
// lattice values with different vector shapes are considered incompatible and
// their join will yield top, unless the lattice priority causes one side to
// override the other during the join.

Copilot uses AI. Check for mistakes.
Comment on lines +2022 to +2027
assert(constraints && "constraints not found");
if (failed(mma.initializeIndexExprsBackward(
operandLattices,
*wave::IndexExprsAnalysisInit::create(
mma->getLoc(), constraints,
wave::getHyperparameters(mma)),
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the MMA special-case path, constraints is required but is only enforced via assert(constraints && "constraints not found"), and then *IndexExprsAnalysisInit::create(...) is dereferenced unconditionally. In release builds, this can crash if an MMA is encountered outside a parent op with wave.constraints. Please replace the assert with a proper check that emits an error and interrupts the walk when constraints are missing (and avoid dereferencing the create(...) result without checking it succeeded).

Suggested change
assert(constraints && "constraints not found");
if (failed(mma.initializeIndexExprsBackward(
operandLattices,
*wave::IndexExprsAnalysisInit::create(
mma->getLoc(), constraints,
wave::getHyperparameters(mma)),
if (!constraints) {
mma->emitError()
<< "missing '" << wave::WaveDialect::kWaveConstraintsAttrName
<< "' attribute on a parent operation required for MMA index "
"expression inference";
return WalkResult::interrupt();
}
auto analysisInit = wave::IndexExprsAnalysisInit::create(
mma->getLoc(), constraints, wave::getHyperparameters(mma));
if (!analysisInit) {
mma->emitError()
<< "failed to initialize index expression analysis for MMA "
"operation";
return WalkResult::interrupt();
}
if (failed(mma.initializeIndexExprsBackward(
operandLattices, *analysisInit,

Copilot uses AI. Check for mistakes.
Comment on lines +85 to +87
<< "expected " << attributeName
<< " to be an array of "
"dictionaries with WaveIndexMappingAttr or UnitAttr values";
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The override parsing error message is now misleading/inconsistent with the accepted formats: this code accepts a bare DictionaryAttr (not just an array), and it currently rejects UnitAttr values inside the indexExprs dictionary (the lambda only permits WaveIndexMappingAttr). Please update the diagnostic text to reflect the actual accepted formats/values, or relax the validation if UnitAttr values are intended to be allowed in the dictionary.

Suggested change
<< "expected " << attributeName
<< " to be an array of "
"dictionaries with WaveIndexMappingAttr or UnitAttr values";
<< "expected each element of " << attributeName
<< " to be either a DictionaryAttr or an array "
"[priority?, indexExprs, vectorShape?], where indexExprs is "
"a DictionaryAttr whose values are WaveIndexMappingAttr";

Copilot uses AI. Check for mistakes.
Comment on lines +1089 to +1096
if (joined.isTop()) {
InFlightDiagnostic diag = emitError()
<< "conflict for " << latticeName
<< " index expression when propagating from "
<< otherName << " lattice";
diag.attachNote() << "original " << latticeName << " lattice: " << lattice;
diag.attachNote() << otherName << " lattice: " << other;
return diag;
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

joinIndexExprsLatticeInPlace reports any join-to-top as an “index expression” conflict, but IndexExprsLatticeStorage::join can also return top due to a vector-shape conflict. This can make diagnostics confusing. Consider detecting vector-shape conflicts (e.g., via IndexExprsLatticeStorage::hasVectorShapeConflict) and tailoring the error/notes to mention the vector-shape mismatch (and include the vectorShape values).

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants