Skip to content

Commit 746aef2

Browse files
committed
[water] Migrate from DictionaryAttr to WaveIndexExprsAttr for index expressions
This commit uses WaveIndexExprsAttr and WaveIndexEntryAttr to replace DictionaryAttr for storing index expressions. The key motivation is that DictionaryAttr alphabetically sorts its entries, but dimension order must be preserved to match the tensor type's shape order for correct lowering. Changes: - Custom parsing/printing that preserves order - Update IndexExprsLatticeStorage to use WaveIndexExprsAttr - Update join() logic to use MapVector for order preservation - Add C API and Python bindings - Update Python emitter/converter The join() logic uses LHS-first ordering policy: LHS entries come first (in LHS order), then new RHS entries (in RHS order). This ensures deterministic output when joining lattices. Signed-off-by: tyb0807 <[email protected]>
1 parent 6b80129 commit 746aef2

File tree

21 files changed

+844
-405
lines changed

21 files changed

+844
-405
lines changed

water/include/water/Dialect/Wave/IR/WaveAttrs.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,4 +589,13 @@ def WaveIndexExprsAttr : AttrDef<WaveDialect, "WaveIndexExprs"> {
589589
}];
590590
}
591591

592+
//-----------------------------------------------------------------------------
593+
// Typed array attributes
594+
//-----------------------------------------------------------------------------
595+
596+
def WaveIndexExprsArrayAttr : TypedArrayAttrBase<WaveIndexExprsAttr,
597+
"array of WaveIndexExprsAttr"> {
598+
let constBuilderCall = "$_builder.getArrayAttr($0)";
599+
}
600+
592601
#endif // WATER_DIALECT_WAVE_WAVEATTRS

water/include/water/Dialect/Wave/IR/WaveInterfaces.h

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,18 @@ class IndexExprsAnalysisInit {
380380

381381
// Lattice for propagating index expressions across wave dialect operations.
382382
// In addition to the bottom and top states, it can represent a concrete state
383-
// manifested as a dictionary attribute mapping symbol names to index mappings.
383+
// manifested as a WaveIndexExprsAttr mapping dimension symbols to index
384+
// mappings. The entries preserve order, corresponding to the tensor type's
385+
// dimension order.
386+
//
384387
// The JOIN function is defined similarly to other lattices with special
385388
// handling for combining thread-dependent and thread-independent index
386389
// expressions.
387390
class IndexExprsLatticeStorage {
388391
public:
389392
IndexExprsLatticeStorage();
390393
IndexExprsLatticeStorage(const IndexExprsLatticeStorage &value) = default;
391-
IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue);
394+
IndexExprsLatticeStorage(wave::WaveIndexExprsAttr concreteValue);
392395

393396
IndexExprsLatticeStorage &
394397
operator=(const IndexExprsLatticeStorage &other) = default;
@@ -404,15 +407,38 @@ class IndexExprsLatticeStorage {
404407

405408
// Returns the concrete value stored in the lattice instance, be it fully
406409
// specified or not, or null if the lattice instance is a top or a bottom.
407-
mlir::DictionaryAttr getConcreteValue() const;
410+
wave::WaveIndexExprsAttr getConcreteValue() const;
408411

409412
// Return the top lattice instance.
410413
static IndexExprsLatticeStorage top();
411414

412415
// Return the bottom lattice instance.
413416
static IndexExprsLatticeStorage bottom();
414417

415-
// Join two lattice instances and return the result.
418+
/// Join two lattice instances and return the result.
419+
///
420+
/// Ordering semantics:
421+
/// - LHS entries come first (in LHS order), then RHS-only entries (in RHS
422+
/// order).
423+
/// - Entries with the same dimension have their mappings merged.
424+
///
425+
/// Valid usage scenarios:
426+
/// 1. **Same dimensions, same order**: Both LHS and RHS have matching
427+
/// dimension order (e.g., both have {M, K}). The result preserves this
428+
/// order. This is the common case for elementwise ops like wave.add.
429+
///
430+
/// 2. **MMA ops**: LHS has {M, K}, RHS has {N, K}, accumulator/result has
431+
/// {M, N}. The `ignoredRhsSymbols` parameter filters dimensions that
432+
/// don't apply (e.g., ignore M when propagating from LHS to result).
433+
/// Joins are done incrementally: start with bottom, join LHS (filtered),
434+
/// then join RHS, then join accumulator.
435+
///
436+
/// 3. **Iterate ops**: Block arguments are joined with iter_args, and
437+
/// terminator operands with results. Both should have matching tensor
438+
/// types and thus matching dimension order.
439+
///
440+
/// If LHS and RHS have conflicting mappings for the same dimension (i.e.,
441+
/// mappings that cannot be merged), the result is `top` (conflict).
416442
static IndexExprsLatticeStorage
417443
join(const IndexExprsLatticeStorage &lhs, const IndexExprsLatticeStorage &rhs,
418444
llvm::ArrayRef<mlir::Attribute> ignoredRhsSymbols = {});
@@ -432,13 +458,14 @@ class IndexExprsLatticeStorage {
432458
// state.
433459
void unsafeSet(const IndexExprsLatticeStorage &value);
434460

435-
// Return a new lattice instance with only the provided symbols present.
461+
// Return a new lattice instance with only the provided symbols present,
462+
// preserving the current order (filtering only, no reordering).
436463
IndexExprsLatticeStorage
437464
keepOnlySymbols(llvm::ArrayRef<wave::WaveSymbolAttr> symbols) const;
438465

439466
// Return a new lattice instance where all expressions no longer have
440467
// references to the provided iterator symbols. Note that this doesn't remove
441-
// elements from the mapping dictionary but updates the mapped expressions.
468+
// elements from the mapping but updates the mapped expressions.
442469
IndexExprsLatticeStorage
443470
withoutIterSymbols(llvm::ArrayRef<wave::WaveSymbolAttr> iterSymbols) const;
444471

@@ -447,8 +474,8 @@ class IndexExprsLatticeStorage {
447474
LLVM_DUMP_METHOD void dump() const;
448475

449476
private:
450-
// The internal storage is either a dictionary attribute with one entry per
451-
// symbol indexing the value or one of the top/bottom flags.
477+
// The internal storage is either a WaveIndexExprsAttr with ordered entries
478+
// per dimension symbol, or one of the top/bottom flags.
452479
llvm::PointerIntPair<mlir::Attribute, 2> value;
453480

454481
// State flags.
@@ -458,6 +485,7 @@ class IndexExprsLatticeStorage {
458485
};
459486

460487
void operator<<(mlir::Diagnostic &diag, const IndexExprsLatticeStorage &value);
488+
461489
} // namespace wave
462490

463491
namespace llvm {

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def WaveMemoryType : Type<Or<[WaveTensorInMemory.predicate, AnyMemRef.predicate]
4040
class WaveOp<string mnemonic, list<Trait> traits = []>
4141
: Op<WaveDialect, mnemonic, !listconcat(traits, [HasWaveIndexMapping])> {
4242
dag commonArguments = (ins
43-
Arg<OptionalAttr<DictArrayAttr>, "Index expression">:$index
43+
Arg<OptionalAttr<WaveIndexExprsArrayAttr>, "Index expression">:$index
4444
);
4545

4646
string commonArgumentsSyntax = "( `index` custom<WaveIndexDict>($index)^ )?";

water/include/water/Dialect/Wave/IR/WaveUtils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ namespace wave {
2121
/// Return the position of the dimension that is vectorized based on the index
2222
/// sequence. The dimension with the largest step is considered to be
2323
/// vectorized. In case of a tie, take the dimension that is farther in the
24-
/// index dictionary, which is secretly a list. Return failure when the index
25-
/// sequence step cannot be evaluated statically.
24+
/// index expressions. Return failure when the index sequence step cannot be
25+
/// evaluated statically.
2626
std::optional<uint64_t>
2727
getPositionOfVectorizedDim(llvm::ArrayRef<wave::WaveSymbolAttr> shape,
28-
mlir::DictionaryAttr indexDict,
28+
wave::WaveIndexExprsAttr indexExprs,
2929
wave::WaveHyperparameterAttr hyper);
3030

3131
// Return the vector shape implied by the index sequence and hyperparameteters,
@@ -34,7 +34,7 @@ getPositionOfVectorizedDim(llvm::ArrayRef<wave::WaveSymbolAttr> shape,
3434
// it cannot be fully evaluated.
3535
llvm::SmallVector<int64_t>
3636
getUncollapsedVectorShape(llvm::ArrayRef<wave::WaveSymbolAttr> shape,
37-
mlir::DictionaryAttr indexDict,
37+
wave::WaveIndexExprsAttr indexExprs,
3838
wave::WaveHyperparameterAttr hyper);
3939

4040
/// Resolve named Wave symbols to concrete integer values using the

water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ class DataFlowSolver;
1818
class SymbolTableCollection;
1919
class Operation;
2020
class Value;
21-
class DictionaryAttr;
2221
} // namespace mlir
2322

2423
namespace wave {
24+
class WaveIndexExprsAttr;
25+
2526
using SetIndexLatticeFn =
26-
llvm::function_ref<void(mlir::Value, mlir::DictionaryAttr)>;
27+
llvm::function_ref<void(mlir::Value, wave::WaveIndexExprsAttr)>;
2728
using OverrideInitializationFn = llvm::function_ref<llvm::LogicalResult(
2829
mlir::Operation *, SetIndexLatticeFn)>;
2930

water/include/water/c/Dialects.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,51 @@ mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr);
131131
MLIR_CAPI_EXPORTED MlirAttribute
132132
mlirWaveIndexMappingAttrGetSymbol(MlirAttribute attr, intptr_t index);
133133

134+
//===---------------------------------------------------------------------===//
135+
// WaveIndexEntryAttr
136+
//===---------------------------------------------------------------------===//
137+
138+
/// Checks whether the given MLIR attribute is a WaveIndexEntryAttr.
139+
MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveIndexEntryAttr(MlirAttribute attr);
140+
141+
/// Creates a new WaveIndexEntryAttr with the given dimension symbol and
142+
/// mapping.
143+
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexEntryAttrGet(
144+
MlirContext mlirCtx, MlirAttribute dimension, MlirAttribute mapping);
145+
146+
/// Returns the typeID of a WaveIndexEntryAttr.
147+
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexEntryAttrGetTypeID();
148+
149+
/// Gets the dimension symbol from a WaveIndexEntryAttr.
150+
MLIR_CAPI_EXPORTED MlirAttribute
151+
mlirWaveIndexEntryAttrGetDimension(MlirAttribute attr);
152+
153+
/// Gets the mapping from a WaveIndexEntryAttr.
154+
MLIR_CAPI_EXPORTED MlirAttribute
155+
mlirWaveIndexEntryAttrGetMapping(MlirAttribute attr);
156+
157+
//===---------------------------------------------------------------------===//
158+
// WaveIndexExprsAttr
159+
//===---------------------------------------------------------------------===//
160+
161+
/// Checks whether the given MLIR attribute is a WaveIndexExprsAttr.
162+
MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveIndexExprsAttr(MlirAttribute attr);
163+
164+
/// Creates a new WaveIndexExprsAttr with the given list of entries.
165+
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexExprsAttrGet(
166+
MlirContext mlirCtx, intptr_t numEntries, MlirAttribute *entries);
167+
168+
/// Returns the typeID of a WaveIndexExprsAttr.
169+
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexExprsAttrGetTypeID();
170+
171+
/// Gets the number of entries in a WaveIndexExprsAttr.
172+
MLIR_CAPI_EXPORTED intptr_t
173+
mlirWaveIndexExprsAttrGetNumEntries(MlirAttribute attr);
174+
175+
/// Gets the entry at the given index from a WaveIndexExprsAttr.
176+
MLIR_CAPI_EXPORTED MlirAttribute
177+
mlirWaveIndexExprsAttrGetEntry(MlirAttribute attr, intptr_t index);
178+
134179
//===---------------------------------------------------------------------===//
135180
// WaveHyperparameterAttr
136181
//===---------------------------------------------------------------------===//

water/lib/CAPI/Dialects.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,71 @@ MlirAttribute mlirWaveIndexMappingAttrGetSymbol(MlirAttribute attr,
164164
llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getSymbols()[index]);
165165
}
166166

167+
//===---------------------------------------------------------------------===//
168+
// WaveIndexEntryAttr
169+
//===---------------------------------------------------------------------===//
170+
171+
bool mlirAttributeIsAWaveIndexEntryAttr(MlirAttribute attr) {
172+
return llvm::isa<wave::WaveIndexEntryAttr>(unwrap(attr));
173+
}
174+
175+
MlirAttribute mlirWaveIndexEntryAttrGet(MlirContext mlirCtx,
176+
MlirAttribute dimension,
177+
MlirAttribute mapping) {
178+
MLIRContext *ctx = unwrap(mlirCtx);
179+
auto dimAttr = llvm::cast<wave::WaveSymbolAttr>(unwrap(dimension));
180+
auto mapAttr = llvm::cast<wave::WaveIndexMappingAttr>(unwrap(mapping));
181+
return wrap(wave::WaveIndexEntryAttr::get(ctx, dimAttr, mapAttr));
182+
}
183+
184+
MlirTypeID mlirWaveIndexEntryAttrGetTypeID() {
185+
return wrap(TypeID::get<wave::WaveIndexEntryAttr>());
186+
}
187+
188+
MlirAttribute mlirWaveIndexEntryAttrGetDimension(MlirAttribute attr) {
189+
return wrap(
190+
llvm::cast<wave::WaveIndexEntryAttr>(unwrap(attr)).getDimension());
191+
}
192+
193+
MlirAttribute mlirWaveIndexEntryAttrGetMapping(MlirAttribute attr) {
194+
return wrap(llvm::cast<wave::WaveIndexEntryAttr>(unwrap(attr)).getMapping());
195+
}
196+
197+
//===---------------------------------------------------------------------===//
198+
// WaveIndexExprsAttr
199+
//===---------------------------------------------------------------------===//
200+
201+
bool mlirAttributeIsAWaveIndexExprsAttr(MlirAttribute attr) {
202+
return llvm::isa<wave::WaveIndexExprsAttr>(unwrap(attr));
203+
}
204+
205+
MlirAttribute mlirWaveIndexExprsAttrGet(MlirContext mlirCtx,
206+
intptr_t numEntries,
207+
MlirAttribute *entries) {
208+
MLIRContext *ctx = unwrap(mlirCtx);
209+
llvm::SmallVector<wave::WaveIndexEntryAttr> entryAttrs;
210+
entryAttrs.reserve(numEntries);
211+
for (intptr_t i = 0; i < numEntries; ++i) {
212+
entryAttrs.push_back(
213+
llvm::cast<wave::WaveIndexEntryAttr>(unwrap(entries[i])));
214+
}
215+
return wrap(wave::WaveIndexExprsAttr::get(ctx, entryAttrs));
216+
}
217+
218+
MlirTypeID mlirWaveIndexExprsAttrGetTypeID() {
219+
return wrap(TypeID::get<wave::WaveIndexExprsAttr>());
220+
}
221+
222+
intptr_t mlirWaveIndexExprsAttrGetNumEntries(MlirAttribute attr) {
223+
return llvm::cast<wave::WaveIndexExprsAttr>(unwrap(attr)).getEntries().size();
224+
}
225+
226+
MlirAttribute mlirWaveIndexExprsAttrGetEntry(MlirAttribute attr,
227+
intptr_t index) {
228+
return wrap(
229+
llvm::cast<wave::WaveIndexExprsAttr>(unwrap(attr)).getEntries()[index]);
230+
}
231+
167232
//===---------------------------------------------------------------------===//
168233
// WaveHyperparameterAttr
169234
//===---------------------------------------------------------------------===//

water/lib/Dialect/Wave/IR/WaveAttrs.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -722,9 +722,12 @@ Attribute WaveIndexEntryAttr::parse(AsmParser &parser, Type type) {
722722

723723
void WaveIndexEntryAttr::print(AsmPrinter &printer) const {
724724
// Print: @M : [symbols] -> (start, step, stride)
725-
printer.printAttributeWithoutType(getDimension());
725+
// Use printStrippedAttrOrType to avoid the #wave.symbol<...> prefix in
726+
// generic mode, so we get just <"M"> instead of #wave.symbol<"M">.
727+
printer << " ";
728+
printer.printStrippedAttrOrType(getDimension());
726729
printer << " : ";
727-
printer.printAttributeWithoutType(getMapping());
730+
printer.printStrippedAttrOrType(getMapping());
728731
}
729732

730733
//===----------------------------------------------------------------------===//
@@ -756,9 +759,10 @@ Attribute WaveIndexExprsAttr::parse(AsmParser &parser, Type type) {
756759

757760
void WaveIndexExprsAttr::print(AsmPrinter &printer) const {
758761
// Print: <[@M : <mapping>, @K : <mapping>]>
762+
// Use printStrippedAttrOrType to avoid #wave<index_entry ...> prefix.
759763
printer << "<[";
760764
llvm::interleaveComma(getEntries(), printer, [&](WaveIndexEntryAttr entry) {
761-
printer.printAttributeWithoutType(entry);
765+
printer.printStrippedAttrOrType(entry);
762766
});
763767
printer << "]>";
764768
}

0 commit comments

Comments
 (0)