Skip to content

Commit 6b80129

Browse files
committed
[water] Add WaveIndexExprsAttr for ordered index expressions
Add new attributes to represent ordered index expressions that preserve dimension order, unlike DictionaryAttr which sorts entries alphabetically. This is the first step toward fixing the index expression ordering issue where dimension order is lost during lattice operations because: 1. DenseMap loses insertion order during join operations 2. DictionaryAttr::get() always sorts entries alphabetically New attributes: - WaveIndexEntryAttr: A single (dimension, mapping) pair - WaveIndexExprsAttr: An ordered array of entries The order corresponds to the tensor type's shape dimension order, which is critical for lowering when the tensor type has been converted to memref and is no longer available. Syntax: index_exprs<[@m : <mapping>, @k : <mapping>, @n : <mapping>]> Signed-off-by: tyb0807 <sontuan.vu@amd.com>
1 parent dc1c706 commit 6b80129

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed

water/include/water/Dialect/Wave/IR/WaveAttrs.td

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,4 +519,74 @@ def WaveReadWriteBoundsAttr : AttrDef<WaveDialect, "WaveReadWriteBounds"> {
519519
}];
520520
}
521521

522+
//-----------------------------------------------------------------------------
523+
// Index expression attributes (ordered)
524+
//-----------------------------------------------------------------------------
525+
526+
def WaveIndexEntryAttr : AttrDef<WaveDialect, "WaveIndexEntry"> {
527+
let mnemonic = "index_entry";
528+
let description = [{
529+
A single entry mapping a tensor dimension symbol to its index mapping.
530+
531+
This is a component of WaveIndexExprsAttr, representing one dimension's
532+
index expression. The dimension is a WaveSymbolAttr (e.g., @M, @K, @N)
533+
and the mapping is a WaveIndexMappingAttr specifying start/step/stride.
534+
535+
Syntax: @M : [symbols] -> (start, step, stride)
536+
Example: @M : [#wave.index_symbol<WG0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M, 1, BLOCK_M)
537+
}];
538+
539+
let parameters = (ins
540+
"::wave::WaveSymbolAttr":$dimension,
541+
"::wave::WaveIndexMappingAttr":$mapping
542+
);
543+
544+
let hasCustomAssemblyFormat = 1;
545+
}
546+
547+
def WaveIndexExprsAttr : AttrDef<WaveDialect, "WaveIndexExprs"> {
548+
let mnemonic = "index_exprs";
549+
let description = [{
550+
An ordered collection of dimension index mappings for Wave tensors.
551+
552+
Unlike DictionaryAttr which sorts entries alphabetically, this attribute
553+
preserves the order of entries as specified. The order corresponds to
554+
the dimension order in the associated WaveTensorType's shape.
555+
556+
Syntax: index_exprs<[@M : <mapping>, @K : <mapping>, @N : <mapping>]>
557+
558+
The entries are stored in an array, maintaining insertion order.
559+
This is critical for lowering where dimension order must match
560+
the tensor type's shape, even after the tensor type has been converted
561+
to memref.
562+
}];
563+
564+
let parameters = (ins
565+
ArrayRefParameter<"::wave::WaveIndexEntryAttr">:$entries
566+
);
567+
568+
let hasCustomAssemblyFormat = 1;
569+
570+
let extraClassDeclaration = [{
571+
/// Look up the index mapping for a given dimension symbol.
572+
/// Returns std::nullopt if the dimension is not found.
573+
/// Complexity: O(n) where n is the number of dimensions (typically 2-4).
574+
std::optional<::wave::WaveIndexMappingAttr>
575+
lookup(::wave::WaveSymbolAttr dimension) const;
576+
577+
/// Look up by dimension name string.
578+
std::optional<::wave::WaveIndexMappingAttr>
579+
lookup(::llvm::StringRef dimensionName) const;
580+
581+
/// Get the ordered list of dimension symbols.
582+
::llvm::SmallVector<::wave::WaveSymbolAttr> getDimensions() const;
583+
584+
/// Get the number of entries.
585+
size_t size() const { return getEntries().size(); }
586+
587+
/// Check if empty.
588+
bool empty() const { return getEntries().empty(); }
589+
}];
590+
}
591+
522592
#endif // WATER_DIALECT_WAVE_WAVEATTRS

water/lib/Dialect/Wave/IR/WaveAttrs.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,98 @@ DeviceConstraintAttr::verify(function_ref<InFlightDiagnostic()> emitError,
697697
return success();
698698
}
699699

700+
//===----------------------------------------------------------------------===//
701+
// WaveIndexEntryAttr
702+
//===----------------------------------------------------------------------===//
703+
704+
// Syntax: @M : [symbols] -> (start, step, stride)
705+
Attribute WaveIndexEntryAttr::parse(AsmParser &parser, Type type) {
706+
// Parse dimension symbol: @M
707+
WaveSymbolAttr dimension;
708+
if (parser.parseCustomAttributeWithFallback<WaveSymbolAttr>(dimension))
709+
return {};
710+
711+
// Parse colon
712+
if (parser.parseColon())
713+
return {};
714+
715+
// Parse the index mapping
716+
WaveIndexMappingAttr mapping;
717+
if (parser.parseCustomAttributeWithFallback<WaveIndexMappingAttr>(mapping))
718+
return {};
719+
720+
return get(parser.getContext(), dimension, mapping);
721+
}
722+
723+
void WaveIndexEntryAttr::print(AsmPrinter &printer) const {
724+
// Print: @M : [symbols] -> (start, step, stride)
725+
printer.printAttributeWithoutType(getDimension());
726+
printer << " : ";
727+
printer.printAttributeWithoutType(getMapping());
728+
}
729+
730+
//===----------------------------------------------------------------------===//
731+
// WaveIndexExprsAttr
732+
//===----------------------------------------------------------------------===//
733+
734+
// Syntax: index_exprs<[@M : <mapping>, @K : <mapping>, @N : <mapping>]>
735+
Attribute WaveIndexExprsAttr::parse(AsmParser &parser, Type type) {
736+
if (parser.parseLess())
737+
return {};
738+
739+
SmallVector<WaveIndexEntryAttr> entries;
740+
741+
// Parse '[' entries ']' allowing empty or non-empty lists
742+
if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
743+
WaveIndexEntryAttr entry;
744+
if (parser.parseCustomAttributeWithFallback<WaveIndexEntryAttr>(entry))
745+
return failure();
746+
entries.push_back(entry);
747+
return success();
748+
}))
749+
return {};
750+
751+
if (parser.parseGreater())
752+
return {};
753+
754+
return get(parser.getContext(), entries);
755+
}
756+
757+
void WaveIndexExprsAttr::print(AsmPrinter &printer) const {
758+
// Print: <[@M : <mapping>, @K : <mapping>]>
759+
printer << "<[";
760+
llvm::interleaveComma(getEntries(), printer, [&](WaveIndexEntryAttr entry) {
761+
printer.printAttributeWithoutType(entry);
762+
});
763+
printer << "]>";
764+
}
765+
766+
std::optional<WaveIndexMappingAttr>
767+
WaveIndexExprsAttr::lookup(WaveSymbolAttr dimension) const {
768+
for (WaveIndexEntryAttr entry : getEntries()) {
769+
if (entry.getDimension() == dimension)
770+
return entry.getMapping();
771+
}
772+
return std::nullopt;
773+
}
774+
775+
std::optional<WaveIndexMappingAttr>
776+
WaveIndexExprsAttr::lookup(StringRef dimensionName) const {
777+
for (WaveIndexEntryAttr entry : getEntries()) {
778+
if (entry.getDimension().getName() == dimensionName)
779+
return entry.getMapping();
780+
}
781+
return std::nullopt;
782+
}
783+
784+
SmallVector<WaveSymbolAttr> WaveIndexExprsAttr::getDimensions() const {
785+
SmallVector<WaveSymbolAttr> dims;
786+
dims.reserve(getEntries().size());
787+
for (WaveIndexEntryAttr entry : getEntries())
788+
dims.push_back(entry.getDimension());
789+
return dims;
790+
}
791+
700792
void wave::WaveDialect::registerAttributes() {
701793
addAttributes<
702794
#define GET_ATTRDEF_LIST

0 commit comments

Comments
 (0)