Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions lit_tests/kernel/wave/mlir_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,48 +248,48 @@ def mlir_converter_matrix_add():
# CHECK-SAME: N = 128 : i64

# CHECK: %[[READ_A:.*]] = wave.read %[[ARG0]]
# CHECK-SAME: index
# CHECK-SAME: M : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: N : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: bounds
# CHECK-SAME: #wave.read_write_bounds
# CHECK-SAME: M = #wave.expr_list
# CHECK-SAME: N = #wave.expr_list
# CHECK-SAME: elements_per_thread = 32 : i64
# CHECK-SAME: index
# CHECK-SAME: <"M"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: <"N"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <global>>) -> !wave.tensor<[@M, @N] of f16, <register>>

# CHECK: %[[READ_B:.*]] = wave.read %[[ARG1]]
# CHECK-SAME: index
# CHECK-SAME: M : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: N : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: bounds
# CHECK-SAME: #wave.read_write_bounds
# CHECK-SAME: M = #wave.expr_list
# CHECK-SAME: N = #wave.expr_list
# CHECK-SAME: elements_per_thread = 32 : i64
# CHECK-SAME: index
# CHECK-SAME: <"M"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: <"N"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <global>>) -> !wave.tensor<[@M, @N] of f16, <register>>

# CHECK: %[[ADD:.*]] = wave.add %[[READ_A]], %[[READ_B]]
# CHECK-SAME: index
# CHECK-SAME: M : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: N : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: <"M"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: <"N"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <register>>, !wave.tensor<[@M, @N] of f16, <register>>) -> !wave.tensor<[@M, @N] of f16, <register>>

# CHECK: %[[CAST:.*]] = wave.cast %[[ADD]]
# CHECK-SAME: index
# CHECK-SAME: M : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: N : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: <"M"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: <"N"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: : !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@M, @N] of f32, <register>>

# CHECK: wave.write %[[CAST]], %[[ARG2]]
# CHECK-SAME: index
# CHECK-SAME: M : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: N : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: bounds
# CHECK-SAME: #wave.read_write_bounds
# CHECK-SAME: M = #wave.expr_list
# CHECK-SAME: N = #wave.expr_list
# CHECK-SAME: elements_per_thread = 32 : i64
# CHECK-SAME: index
# CHECK-SAME: <"M"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)
# CHECK-SAME: <"N"> : [{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)
# CHECK-SAME: !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32, <global>>

# CHECK: return
Expand Down Expand Up @@ -425,7 +425,7 @@ def pipeline(root: OpHandle):
# Python propagation algorithm that is immediately caught by the verifier on construction.
#
# CHECK-SAME: index =
# CHECK-SAME: K = #wave<index_mapping
# CHECK-SAME: <"K"> :
# CHECK-NOT: ARGK
#
# CHECK-NEXT: %[[ALLOCATE_2:.*]] = wave.allocate
Expand All @@ -446,20 +446,24 @@ def pipeline(root: OpHandle):
# CHECK-NEXT: %[[READ_SHARED_B_2:.*]] = wave.read %[[ALLOCATE_2]]
# CHECK-NEXT: %[[READ_SHARED_B_3:.*]] = wave.read %[[ALLOCATE_2]]
# CHECK-NEXT: %[[MMA_0:.*]] = wave.mma %[[READ_SHARED_B_0]], %[[READ_SHARED_A_0]], %[[ARG3]]
# CHECK-COUNT-2: {K : [
# CHECK-SAME: {M : [
# CHECK-SAME: index =
# CHECK-COUNT-2: <"K"> :
# CHECK-SAME: <"M"> :
# CHECK-SAME: #wave.mma_kind<f32_32x32x8_f16>
# CHECK-NEXT: %[[MMA_1:.*]] = wave.mma %[[READ_SHARED_B_1]], %[[READ_SHARED_A_1]], %[[MMA_0]]
# CHECK-COUNT-2: {K : [
# CHECK-SAME: {M : [
# CHECK-SAME: index =
# CHECK-COUNT-2: <"K"> :
# CHECK-SAME: <"M"> :
# CHECK-SAME: #wave.mma_kind<f32_32x32x8_f16>
# CHECK-NEXT: %[[MMA_2:.*]] = wave.mma %[[READ_SHARED_B_2]], %[[READ_SHARED_A_2]], %[[MMA_1]]
# CHECK-COUNT-2: {K : [
# CHECK-SAME: {M : [
# CHECK-SAME: index =
# CHECK-COUNT-2: <"K"> :
# CHECK-SAME: <"M"> :
# CHECK-SAME: #wave.mma_kind<f32_32x32x8_f16>
# CHECK-NEXT: %[[MMA_3:.*]] = wave.mma %[[READ_SHARED_B_3]], %[[READ_SHARED_A_3]], %[[MMA_2]]
# CHECK-COUNT-2: {K : [
# CHECK-SAME: {M : [
# CHECK-SAME: index =
# CHECK-COUNT-2: <"K"> :
# CHECK-SAME: <"M"> :
# CHECK-SAME: #wave.mma_kind<f32_32x32x8_f16>
# CHECK-NEXT: wave.yield %[[MMA_3]] : !wave.tensor<[@M, @N] of f32, <register>>
# CHECK-NEXT: }
Expand Down
70 changes: 52 additions & 18 deletions tests/mlir_wave_iface/mlir_to_wave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from mlir_converter.mlir_to_wave import (
_convert_affine_expr_to_sympy_expr,
_convert_index_mapping_attr_to_sympy,
_convert_index_mapping_dict_to_sympy,
_convert_index_exprs_to_sympy,
convert_index_mapping_array_to_sympy,
_make_piecewise_sequence,
ITER_SYMBOL_NAME_WAVE_PREFIX,
Expand Down Expand Up @@ -308,11 +308,11 @@ def test_index_mapping_with_null_step_stride(self):
assert result.stride is None


class TestConvertIndexMappingDictToSympy:
"""Tests for _convert_index_mapping_dict_to_sympy function."""
class TestConvertIndexExprsToSympy:
"""Tests for _convert_index_exprs_to_sympy function."""

def test_single_mapping(self):
"""Test conversion of dict with single index mapping."""
"""Test conversion of WaveIndexExprsAttr with single entry."""
# Create a simple index mapping
symbols = [wave.WaveSymbolAttr.get("M")]
s0 = ir.AffineSymbolExpr.get(0)
Expand All @@ -323,10 +323,12 @@ def test_single_mapping(self):
symbols, start_map, step_map, stride_map
)

# Create dict attribute
dict_attr = ir.DictAttr.get({"dim0": mapping_attr})
# Create WaveIndexExprsAttr with a single entry
dim_attr = wave.WaveSymbolAttr.get("dim0")
entry = wave.WaveIndexEntryAttr.get(dim_attr, mapping_attr)
index_exprs_attr = wave.WaveIndexExprsAttr.get([entry])

result = _convert_index_mapping_dict_to_sympy(dict_attr)
result = _convert_index_exprs_to_sympy(index_exprs_attr)

assert isinstance(result, dict)
assert index_symbol("dim0") in result
Expand All @@ -335,7 +337,7 @@ def test_single_mapping(self):
assert result[index_symbol("dim0")].size == 16

def test_multiple_mappings(self):
"""Test conversion of dict with multiple index mappings."""
"""Test conversion of WaveIndexExprsAttr with multiple entries."""
# Create first mapping
symbols1 = [wave.WaveSymbolAttr.get("M")]
s0 = ir.AffineSymbolExpr.get(0)
Expand All @@ -355,8 +357,14 @@ def test_multiple_mappings(self):
ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(2)]),
)

dict_attr = ir.DictAttr.get({"m": mapping1, "n": mapping2})
result = _convert_index_mapping_dict_to_sympy(dict_attr)
# Create WaveIndexExprsAttr with multiple entries
dim_m = wave.WaveSymbolAttr.get("m")
dim_n = wave.WaveSymbolAttr.get("n")
entry1 = wave.WaveIndexEntryAttr.get(dim_m, mapping1)
entry2 = wave.WaveIndexEntryAttr.get(dim_n, mapping2)
index_exprs_attr = wave.WaveIndexExprsAttr.get([entry1, entry2])

result = _convert_index_exprs_to_sympy(index_exprs_attr)

assert len(result) == 2
assert index_symbol("m") in result
Expand Down Expand Up @@ -414,8 +422,11 @@ def test_non_mma_op_single_mapping(self):
ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(1)]),
)

dict_attr = ir.DictAttr.get({"dim": mapping})
array_attr = ir.ArrayAttr.get([dict_attr])
# Create WaveIndexExprsAttr
dim_attr = wave.WaveSymbolAttr.get("dim")
entry = wave.WaveIndexEntryAttr.get(dim_attr, mapping)
index_exprs_attr = wave.WaveIndexExprsAttr.get([entry])
array_attr = ir.ArrayAttr.get([index_exprs_attr])

# We don't need anything from the operation except its name, so use an empty module.
dummy_op = ir.Operation.create("builtin.module", loc=ir.Location.unknown())
Expand Down Expand Up @@ -477,13 +488,36 @@ def test_mma_op_with_valid_four_mappings(self):
ir.AffineMap.get(0, 1, [c1]),
)

# Note that result mapping is the same as the accumulator mapping.
lhs_dict = ir.DictAttr.get({"M": lhs_m_mapping, "K": lhs_k_mapping})
rhs_dict = ir.DictAttr.get({"N": rhs_n_mapping, "K": rhs_k_mapping})
acc_dict = ir.DictAttr.get({"M": acc_m_mapping, "N": acc_n_mapping})
result_dict = ir.DictAttr.get({"M": acc_m_mapping, "N": acc_n_mapping})
# Create WaveIndexExprsAttr for each operand
m_dim = wave.WaveSymbolAttr.get("M")
n_dim = wave.WaveSymbolAttr.get("N")
k_dim = wave.WaveSymbolAttr.get("K")

lhs_entries = [
wave.WaveIndexEntryAttr.get(m_dim, lhs_m_mapping),
wave.WaveIndexEntryAttr.get(k_dim, lhs_k_mapping),
]
rhs_entries = [
wave.WaveIndexEntryAttr.get(n_dim, rhs_n_mapping),
wave.WaveIndexEntryAttr.get(k_dim, rhs_k_mapping),
]
acc_entries = [
wave.WaveIndexEntryAttr.get(m_dim, acc_m_mapping),
wave.WaveIndexEntryAttr.get(n_dim, acc_n_mapping),
]
result_entries = [
wave.WaveIndexEntryAttr.get(m_dim, acc_m_mapping),
wave.WaveIndexEntryAttr.get(n_dim, acc_n_mapping),
]

lhs_index_exprs = wave.WaveIndexExprsAttr.get(lhs_entries)
rhs_index_exprs = wave.WaveIndexExprsAttr.get(rhs_entries)
acc_index_exprs = wave.WaveIndexExprsAttr.get(acc_entries)
result_index_exprs = wave.WaveIndexExprsAttr.get(result_entries)

array_attr = ir.ArrayAttr.get([lhs_dict, rhs_dict, acc_dict, result_dict])
array_attr = ir.ArrayAttr.get(
[lhs_index_exprs, rhs_index_exprs, acc_index_exprs, result_index_exprs]
)

# Create a mock MMA operation, we only need the name, it doesn't even need to verify correctly.
dummy_mma_op = ir.Operation.create("wave.mma", loc=ir.Location.unknown())
Expand Down
79 changes: 79 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -519,4 +519,83 @@ def WaveReadWriteBoundsAttr : AttrDef<WaveDialect, "WaveReadWriteBounds"> {
}];
}

//-----------------------------------------------------------------------------
// Index expression attributes (ordered)
//-----------------------------------------------------------------------------

def WaveIndexEntryAttr : AttrDef<WaveDialect, "WaveIndexEntry"> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of having entry as an attribute? Do we ever need a single entry? Otherwise, we are just increasing the cost of manipulating these objects without benefit: attributes are created under lock and accessing each attribute adds a pointer indirection to get to the context-owned memory. Unless we use or intend to use index entries separately, we just store an array of pairs, pretty much like DictAttr does.

let mnemonic = "index_entry";
let description = [{
A single entry mapping a tensor dimension symbol to its index mapping.

This is a component of WaveIndexExprsAttr, representing one dimension's
index expression. The dimension is a WaveSymbolAttr (e.g., @M, @K, @N)
and the mapping is a WaveIndexMappingAttr specifying start/step/stride.

Syntax: @M : [symbols] -> (start, step, stride)
Example: @M : [#wave.index_symbol<WG0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M, 1, BLOCK_M)
Comment on lines +535 to +536
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THis PR literally has tests that show different syntax. Don't put examples in here, they are bitrotten before you even sent your PR!

}];

let parameters = (ins
"::wave::WaveSymbolAttr":$dimension,
"::wave::WaveIndexMappingAttr":$mapping
);

let hasCustomAssemblyFormat = 1;
}

def WaveIndexExprsAttr : AttrDef<WaveDialect, "WaveIndexExprs"> {
let mnemonic = "index_exprs";
let description = [{
An ordered collection of dimension index mappings for Wave tensors.

Unlike DictionaryAttr which sorts entries alphabetically, this attribute
preserves the order of entries as specified. The order corresponds to
the dimension order in the associated WaveTensorType's shape.

Syntax: index_exprs<[@M : <mapping>, @K : <mapping>, @N : <mapping>]>

The entries are stored in an array, maintaining insertion order.
This is critical for lowering where dimension order must match
the tensor type's shape, even after the tensor type has been converted
to memref.
}];

let parameters = (ins
ArrayRefParameter<"::wave::WaveIndexEntryAttr">:$entries
);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
/// Look up the index mapping for a given dimension symbol.
/// Returns std::nullopt if the dimension is not found.
/// Complexity: O(n) where n is the number of dimensions (typically 2-4).
std::optional<::wave::WaveIndexMappingAttr>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You never need optional around a type that is nullable. Just use null as failure. Otherwise you have two "failure" states with no clear difference.

lookup(::wave::WaveSymbolAttr dimension) const;

/// Look up by dimension name string.
std::optional<::wave::WaveIndexMappingAttr>
lookup(::llvm::StringRef dimensionName) const;
Comment on lines +577 to +579
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to avoid string-based APIs as much as possible. If we don't need it, let's remove this.


/// Get the ordered list of dimension symbols.
::llvm::SmallVector<::wave::WaveSymbolAttr> getDimensions() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think you can make the return type auto if you put implementation here, and just return llvm::map_range(...) (the actual type is atrocious), which will avoid the need to construct and potentially copy a vector.


/// Get the number of entries.
size_t size() const { return getEntries().size(); }

/// Check if empty.
bool empty() const { return getEntries().empty(); }
}];
}

//-----------------------------------------------------------------------------
// Typed array attributes
//-----------------------------------------------------------------------------

def WaveIndexExprsArrayAttr : TypedArrayAttrBase<WaveIndexExprsAttr,
"array of WaveIndexExprsAttr"> {
let constBuilderCall = "$_builder.getArrayAttr($0)";
Copy link
Contributor

@tgymnich tgymnich Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave the constBuilderCall out and just use the default builder implementation?

}

#endif // WATER_DIALECT_WAVE_WAVEATTRS
Loading