@@ -380,15 +380,18 @@ class IndexExprsAnalysisInit {
380380
381381// Lattice for propagating index expressions across wave dialect operations.
382382// In addition to the bottom and top states, it can represent a concrete state
383- // manifested as a dictionary attribute mapping symbol names to index mappings.
383+ // manifested as a WaveIndexExprsAttr mapping dimension symbols to index
384+ // mappings. The entries preserve order, corresponding to the tensor type's
385+ // dimension order.
386+ //
384387// The JOIN function is defined similarly to other lattices with special
385388// handling for combining thread-dependent and thread-independent index
386389// expressions.
387390class IndexExprsLatticeStorage {
388391public:
389392 IndexExprsLatticeStorage ();
390393 IndexExprsLatticeStorage (const IndexExprsLatticeStorage &value) = default ;
391- IndexExprsLatticeStorage (mlir::DictionaryAttr concreteValue);
394+ IndexExprsLatticeStorage (wave::WaveIndexExprsAttr concreteValue);
392395
393396 IndexExprsLatticeStorage &
394397 operator =(const IndexExprsLatticeStorage &other) = default ;
@@ -404,15 +407,38 @@ class IndexExprsLatticeStorage {
404407
405408 // Returns the concrete value stored in the lattice instance, be it fully
406409 // specified or not, or null if the lattice instance is a top or a bottom.
407- mlir::DictionaryAttr getConcreteValue () const ;
410+ wave::WaveIndexExprsAttr getConcreteValue () const ;
408411
409412 // Return the top lattice instance.
410413 static IndexExprsLatticeStorage top ();
411414
412415 // Return the bottom lattice instance.
413416 static IndexExprsLatticeStorage bottom ();
414417
415- // Join two lattice instances and return the result.
418+ // / Join two lattice instances and return the result.
419+ // /
420+ // / Ordering semantics:
421+ // / - LHS entries come first (in LHS order), then RHS-only entries (in RHS
422+ // / order).
423+ // / - Entries with the same dimension have their mappings merged.
424+ // /
425+ // / Valid usage scenarios:
426+ // / 1. **Same dimensions, same order**: Both LHS and RHS have matching
427+ // / dimension order (e.g., both have {M, K}). The result preserves this
428+ // / order. This is the common case for elementwise ops like wave.add.
429+ // /
430+ // / 2. **MMA ops**: LHS has {M, K}, RHS has {N, K}, accumulator/result has
431+ // / {M, N}. The `ignoredRhsSymbols` parameter filters dimensions that
432+ // / don't apply (e.g., ignore M when propagating from LHS to result).
433+ // / Joins are done incrementally: start with bottom, join LHS (filtered),
434+ // / then join RHS, then join accumulator.
435+ // /
436+ // / 3. **Iterate ops**: Block arguments are joined with iter_args, and
437+ // / terminator operands with results. Both should have matching tensor
438+ // / types and thus matching dimension order.
439+ // /
440+ // / If LHS and RHS have conflicting mappings for the same dimension (i.e.,
441+ // / mappings that cannot be merged), the result is `top` (conflict).
416442 static IndexExprsLatticeStorage
417443 join (const IndexExprsLatticeStorage &lhs, const IndexExprsLatticeStorage &rhs,
418444 llvm::ArrayRef<mlir::Attribute> ignoredRhsSymbols = {});
@@ -432,13 +458,14 @@ class IndexExprsLatticeStorage {
432458 // state.
433459 void unsafeSet (const IndexExprsLatticeStorage &value);
434460
435- // Return a new lattice instance with only the provided symbols present.
461+ // Return a new lattice instance with only the provided symbols present,
462+ // preserving the current order (filtering only, no reordering).
436463 IndexExprsLatticeStorage
437464 keepOnlySymbols (llvm::ArrayRef<wave::WaveSymbolAttr> symbols) const ;
438465
439466 // Return a new lattice instance where all expressions no longer have
440467 // references to the provided iterator symbols. Note that this doesn't remove
441- // elements from the mapping dictionary but updates the mapped expressions.
468+ // elements from the mapping but updates the mapped expressions.
442469 IndexExprsLatticeStorage
443470 withoutIterSymbols (llvm::ArrayRef<wave::WaveSymbolAttr> iterSymbols) const ;
444471
@@ -447,8 +474,8 @@ class IndexExprsLatticeStorage {
447474 LLVM_DUMP_METHOD void dump () const ;
448475
449476private:
450- // The internal storage is either a dictionary attribute with one entry per
451- // symbol indexing the value or one of the top/bottom flags.
477+ // The internal storage is either a WaveIndexExprsAttr with ordered entries
478+ // per dimension symbol, or one of the top/bottom flags.
452479 llvm::PointerIntPair<mlir::Attribute, 2 > value;
453480
454481 // State flags.
@@ -458,6 +485,7 @@ class IndexExprsLatticeStorage {
458485};
459486
460487void operator <<(mlir::Diagnostic &diag, const IndexExprsLatticeStorage &value);
488+
461489} // namespace wave
462490
463491namespace llvm {
0 commit comments