-
Notifications
You must be signed in to change notification settings - Fork 28
[water] Replace DictionaryAttr with WaveIndexExprsAttr for ordered index expressions #730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -519,4 +519,83 @@ def WaveReadWriteBoundsAttr : AttrDef<WaveDialect, "WaveReadWriteBounds"> { | |
| }]; | ||
| } | ||
|
|
||
| //----------------------------------------------------------------------------- | ||
| // Index expression attributes (ordered) | ||
| //----------------------------------------------------------------------------- | ||
|
|
||
| def WaveIndexEntryAttr : AttrDef<WaveDialect, "WaveIndexEntry"> { | ||
| let mnemonic = "index_entry"; | ||
| let description = [{ | ||
| A single entry mapping a tensor dimension symbol to its index mapping. | ||
|
|
||
| This is a component of WaveIndexExprsAttr, representing one dimension's | ||
| index expression. The dimension is a WaveSymbolAttr (e.g., @M, @K, @N) | ||
| and the mapping is a WaveIndexMappingAttr specifying start/step/stride. | ||
|
|
||
| Syntax: @M : [symbols] -> (start, step, stride) | ||
| Example: @M : [#wave.index_symbol<WG0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M, 1, BLOCK_M) | ||
|
Comment on lines
+535
to
+536
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. THis PR literally has tests that show different syntax. Don't put examples in here, they are bitrotten before you even sent your PR! |
||
| }]; | ||
|
|
||
| let parameters = (ins | ||
| "::wave::WaveSymbolAttr":$dimension, | ||
| "::wave::WaveIndexMappingAttr":$mapping | ||
| ); | ||
|
|
||
| let hasCustomAssemblyFormat = 1; | ||
| } | ||
|
|
||
| def WaveIndexExprsAttr : AttrDef<WaveDialect, "WaveIndexExprs"> { | ||
| let mnemonic = "index_exprs"; | ||
| let description = [{ | ||
| An ordered collection of dimension index mappings for Wave tensors. | ||
|
|
||
| Unlike DictionaryAttr which sorts entries alphabetically, this attribute | ||
| preserves the order of entries as specified. The order corresponds to | ||
| the dimension order in the associated WaveTensorType's shape. | ||
|
|
||
| Syntax: index_exprs<[@M : <mapping>, @K : <mapping>, @N : <mapping>]> | ||
|
|
||
| The entries are stored in an array, maintaining insertion order. | ||
| This is critical for lowering where dimension order must match | ||
| the tensor type's shape, even after the tensor type has been converted | ||
| to memref. | ||
| }]; | ||
|
|
||
| let parameters = (ins | ||
| ArrayRefParameter<"::wave::WaveIndexEntryAttr">:$entries | ||
| ); | ||
|
|
||
| let hasCustomAssemblyFormat = 1; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| /// Look up the index mapping for a given dimension symbol. | ||
| /// Returns std::nullopt if the dimension is not found. | ||
| /// Complexity: O(n) where n is the number of dimensions (typically 2-4). | ||
| std::optional<::wave::WaveIndexMappingAttr> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You never need optional around a type that is nullable. Just use null as failure. Otherwise you have two "failure" states with no clear difference. |
||
| lookup(::wave::WaveSymbolAttr dimension) const; | ||
|
|
||
| /// Look up by dimension name string. | ||
| std::optional<::wave::WaveIndexMappingAttr> | ||
| lookup(::llvm::StringRef dimensionName) const; | ||
|
Comment on lines
+577
to
+579
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like to avoid string-based APIs as much as possible. If we don't need it, let's remove this. |
||
|
|
||
| /// Get the ordered list of dimension symbols. | ||
| ::llvm::SmallVector<::wave::WaveSymbolAttr> getDimensions() const; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I think you can make the return type |
||
|
|
||
| /// Get the number of entries. | ||
| size_t size() const { return getEntries().size(); } | ||
|
|
||
| /// Check if empty. | ||
| bool empty() const { return getEntries().empty(); } | ||
| }]; | ||
| } | ||
|
|
||
| //----------------------------------------------------------------------------- | ||
| // Typed array attributes | ||
| //----------------------------------------------------------------------------- | ||
|
|
||
| def WaveIndexExprsArrayAttr : TypedArrayAttrBase<WaveIndexExprsAttr, | ||
| "array of WaveIndexExprsAttr"> { | ||
| let constBuilderCall = "$_builder.getArrayAttr($0)"; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we leave the |
||
| } | ||
|
|
||
| #endif // WATER_DIALECT_WAVE_WAVEATTRS | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the point of having entry as an attribute? Do we ever need a single entry? Otherwise, we are just increasing the cost of manipulating these objects without benefit: attributes are created under lock and accessing each attribute adds a pointer indirection to get to the context-owned memory. Unless we use or intend to use index entries separately, we just store an array of pairs, pretty much like
DictAttrdoes.