
Conversation


@tgymnich tgymnich commented Jan 20, 2026

Add a wave.permute op that permutes the dimensions of a WaveTensorInRegister.
It lowers to a no-op, but it modifies the propagation of index expressions by permuting the strides according to target_shape.
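Conceptually, the lowering leaves the register contents untouched and only reorders the per-dimension index information. A minimal sketch of that stride reordering, using hypothetical names (permuteStrides and the plain std containers are illustration only, not the PR's API):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Reorder per-dimension strides so that entry i of the result corresponds to
// targetShape[i]. The register data is never moved; only the association
// between symbolic dimensions and strides changes.
static std::vector<int64_t>
permuteStrides(const std::vector<std::string> &srcShape,
               const std::vector<int64_t> &srcStrides,
               const std::vector<std::string> &targetShape) {
  assert(srcShape.size() == targetShape.size() && "ranks must match");
  std::vector<int64_t> result;
  result.reserve(targetShape.size());
  for (const std::string &symbol : targetShape) {
    // Find the symbol's position in the source shape and reuse its stride.
    auto it = std::find(srcShape.begin(), srcShape.end(), symbol);
    assert(it != srcShape.end() && "target_shape must permute the input");
    result.push_back(srcStrides[it - srcShape.begin()]);
  }
  return result;
}

For example, an input shaped [@B, @M, @N] with strides [s0, s1, s2] and target_shape [@M, @N, @B] yields strides [s1, s2, s0].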

@tgymnich tgymnich linked an issue Jan 20, 2026 that may be closed by this pull request
@tgymnich tgymnich force-pushed the tim/permute-op branch 4 times, most recently from 474d06e to 9dd20e3 on January 22, 2026 at 14:24
@tgymnich tgymnich marked this pull request as ready for review January 22, 2026 14:27

@martin-luecke martin-luecke left a comment


wave.permute uses the new CompatibleOperandsAndResultsIgnoreShapeOpTrait as the verifier for its input and result values, which ignores the shape entirely.
But for a permute the ranks should match, and target_shape should be a permutation of the input symbols. I think we should enforce this in the verifier; otherwise we can end up with mismatched ranks or symbol sets that later trip assertions (e.g., in zip_equal) or mis-infer.
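A sketch of what such a check could look like (the accessor names getValue and getTargetShape are assumptions based on the TableGen arguments quoted below, not the actual API):

// Sketch only: enforce rank equality and that target_shape is a permutation
// of the input symbols (shape symbols are assumed to be unique).
LogicalResult PermuteOp::verify() {
  auto inputType = dyn_cast<WaveTensorType>(getValue().getType());
  if (!inputType || !inputType.getFullySpecified())
    return success(); // Nothing to check until the input shape is known.

  ArrayRef<WaveSymbolAttr> inputShape = inputType.getShape();
  ArrayRef<WaveSymbolAttr> targetShape = getTargetShape();
  if (inputShape.size() != targetShape.size())
    return emitOpError() << "target_shape rank (" << targetShape.size()
                         << ") does not match input rank ("
                         << inputShape.size() << ")";

  llvm::SmallDenseSet<WaveSymbolAttr> inputSymbols(inputShape.begin(),
                                                   inputShape.end());
  for (WaveSymbolAttr symbol : targetShape)
    if (!inputSymbols.contains(symbol))
      return emitOpError() << "target_shape symbol " << symbol
                           << " is not a dimension of the input";
  return success();
}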

Also, I think I found a couple of propagation issues I’ve flagged inline.

"iterate": IterateOp,
"output": YieldOp,
"write": WriteOp,
"permute": PermuteOp,
Contributor

When adding it here, please also add a permute op to one of the pywave → wave dialect tests, or create a new test for it, so we know it actually works.

Contributor Author

added tests

resultType = WaveTensorType::get(
getContext(), targetShape, /*fully_specified=*/true,
resultType ? resultType.getElementType() : inputType.getElementType(),
resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());
Contributor

I think this ternary can lose information when resultType.getAddressSpace() == wave::WaveAddressSpace::Unspecified but inputType actually has a specified address space.
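Something along these lines would preserve it (a sketch reusing the names from the quoted snippet):

// Fall back to the input's address space whenever the result's is still
// Unspecified, so a partially inferred result cannot erase information.
wave::WaveAddressSpace addressSpace =
    (resultType &&
     resultType.getAddressSpace() != wave::WaveAddressSpace::Unspecified)
        ? resultType.getAddressSpace()
        : inputType.getAddressSpace();

resultType = WaveTensorType::get(
    getContext(), targetShape, /*fully_specified=*/true,
    resultType ? resultType.getElementType() : inputType.getElementType(),
    addressSpace);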

Contributor

Also, propagateForward only checks that resultType matches target_shape (when the result is already fully specified). It never validates that target_shape is a permutation of the input shape. That means you can end up setting a fully specified result whose rank/symbols don't correspond to the input, while the index-expression code later assumes equal sizes. Can we add an explicit input<>target permutation check (or enforce it in the op verifier)?

Contributor

This definitely should be in the op verifier, but we may also want another diagnostic here. For UX it may be better to say "inference resulted in shape conflict" than just "shape conflict" on shapeless DSL input.

Comment on lines 517 to 519
compilation. At lowering time, the operation is a pass-through since the
actual data layout in registers remains unchanged - only the interpretation
of which dimension each element belongs to changes.
Contributor

I have always found this concept of pass-through permutation difficult to grasp; could this be illustrated with an example?

}];
let arguments = !con((ins
Arg<WaveTensorInRegister, "Value to permute">:$value,
Arg<WaveSymbolArrayAttr , "Target dimension ordering">:$target_shape
Contributor

Do we need this attribute? When the result type is a tensor, it duplicates the tensor's shape. Is there a situation where we need it once the result type has been lowered to a vector?

Contributor Author

Good idea. Less to verify. We shouldn't need the shape downstream.

Comment on lines 1861 to 1877
// If result is already fully specified, verify it matches target_shape.
if (resultType && resultType.getFullySpecified()) {
ArrayRef<WaveSymbolAttr> resultShape = resultType.getShape();
if (resultShape.size() != targetShape.size()) {
errs << "result shape rank (" << resultShape.size()
<< ") does not match target_shape rank (" << targetShape.size()
<< ")";
return failure();
}
for (auto [i, expected, actual] :
llvm::enumerate(targetShape, resultShape)) {
if (expected != actual) {
errs << "result shape dimension #" << i << " (" << actual
<< ") does not match target_shape (" << expected << ")";
return failure();
}
}
Contributor

We have a helper function for this, checkPropagateShapeConflict.

resultType = WaveTensorType::get(
getContext(), targetShape, /*fully_specified=*/true,
resultType ? resultType.getElementType() : inputType.getElementType(),
resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());
Contributor

This definitely should be in the op verifier, but we may also want another diagnostic here. For UX it may be better to say "inference resulted in shape conflict" than just "shape conflict" on shapeless DSL input.

resultType ? resultType.getElementType() : inputType.getElementType(),
resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());

return ChangeResult::Change;
Contributor

We need to check whether the result type actually changes; otherwise the analysis may never converge, since we always indicate a change. Does it?
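The usual lattice pattern is to compare against the previous value and only report a change when one actually occurred; roughly (a sketch based on the quoted snippet):

// Only report a change when the recomputed type differs from the current
// one, so the fixpoint iteration can terminate.
WaveTensorType newResultType = WaveTensorType::get(
    getContext(), targetShape, /*fully_specified=*/true,
    resultType ? resultType.getElementType() : inputType.getElementType(),
    resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());
if (newResultType == resultType)
  return ChangeResult::NoChange;
resultType = newResultType;
return ChangeResult::Change;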

Comment on lines 1906 to 1918
// Verify result shape matches target_shape.
if (resultShape.size() != targetShape.size()) {
errs << "result shape rank (" << resultShape.size()
<< ") does not match target_shape rank (" << targetShape.size() << ")";
return failure();
}
for (auto [i, expected, actual] : llvm::enumerate(targetShape, resultShape)) {
if (expected != actual) {
errs << "result shape dimension #" << i << " (" << actual
<< ") does not match target_shape (" << expected << ")";
return failure();
}
}
Contributor

Same as above; we should have a helper doing this.

@tgymnich tgymnich force-pushed the tim/permute-op branch 4 times, most recently from f0bcf95 to 47200f9 on January 28, 2026 at 09:54
Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>

@martin-luecke martin-luecke left a comment


With the changes this LGTM, with a few nits still to address.

resultShapeSet.insert_range(resultType.getShape());

for (auto inputDim : inputType.getShape()) {
auto [_, inserted] = resultShapeSet.insert(inputDim);
Contributor

Why not use contains() to check whether inputDim is in the set? It seems more idiomatic, e.g.:
if (!resultShapeSet.contains(inputDim)) {

Contributor Author

good point

Comment on lines +1911 to +1917
if (inputType.getFullySpecified()) {
std::string errorMessage;
llvm::raw_string_ostream errs(errorMessage);
if (failed(validatePermutationInput(inputType, resultType, errs))) {
return emitOpError() << errorMessage;
}
}
Contributor

If you passed the op into validatePermutationInput as well, you wouldn't need the string here and could emit the error directly from inside it.
I think the MLIR-idiomatic approach to passing error messages is to construct an InFlightDiagnostic from the op and then compose it within validatePermutationInput.
However, that would require calling abandon() on it when there was no error, so I think the first approach is cleaner.
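For reference, a sketch of the first variant (the extra Operation * parameter is the suggested addition; the checks themselves mirror the diff):

// Emit through the op directly instead of buffering the message in a string.
static LogicalResult validatePermutationInput(Operation *op,
                                              WaveTensorType inputType,
                                              WaveTensorType resultType) {
  if (inputType.getShape().size() != resultType.getShape().size())
    return op->emitOpError()
           << "result shape rank (" << resultType.getShape().size()
           << ") does not match input shape rank ("
           << inputType.getShape().size() << ")";
  // ... remaining permutation checks, each emitting via op->emitOpError() ...
  return success();
}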

Contributor Author

This is due to how propagateForward and propagateBackward emit errors: if we emitted directly here, it would mess up the order of the messages. Might be something to fix later, e.g. by passing an InFlightDiagnostic and adding notes.

Comment on lines +206 to +215
// CHECK: !wave.tensor<[@B, @M, @N] of f32, <register>> to !wave.tensor<[@M, @N, @B] of f32, <register>>
wave.permute %a : !wave.tensor<[@B, @M, @N] of f32, <register>> to !wave.tensor<[@M, @N, @B] of f32, <register>>
return
}

// CHECK-LABEL: @propagate_permute_2d
func.func @propagate_permute_2d(%a: !wave.tensor<[@M, @N] of f16, <register>>) {
// CHECK: !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
wave.permute %a : !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
return
Contributor

These CHECK lines only check parsing of the op; I don't see anything regarding type inference.
Could they just be removed in favor of the tests in ops.mlir, or do they test anything in addition?

Contributor

I think we can test the following:

%0 = wave.negate %arg0 : @A, @B, @C -> any
wave.permute %0 any to @M, @N, @K

This will not fail the verifier, but type inference should discover and report the conflict.

Contributor Author

Yes, this does not really propagate anything anymore, but it's good to keep to make sure it does not fail. I'll also add the negative case.


@ftynse ftynse left a comment


Nice, LGTM % nits

// PermuteOp
//-----------------------------------------------------------------------------

static LogicalResult validatePermutationInput(WaveTensorType inputType,
Contributor

Nit: please document functions.

Contributor Author

added docs

Comment on lines +1898 to +1899
// If result / input is a vector (post-lowering phase), skip wave tensor
// checks.
Contributor

Nit: we want to verify element types for vectors as well. I added a helper function recently.

}

// Result type is already specified, propagate it.
return detail::propagateShapeInformation(resultType, resultType,
Contributor

Does this do anything useful? It looks like this will always succeed, and resultType should already be initialized.

Contributor Author

removed

Comment on lines 1975 to 1980
if (srcShape.size() != targetShape.size()) {
emitError() << "source shape rank (" << srcShape.size()
<< ") does not match target shape rank (" << targetShape.size()
<< ")";
return IndexExprsLatticeStorage::top();
}
Contributor

I think this should be an assertion. We check in the verifier that shapes have equal ranks when present, and we also have a normal-form precondition for this entire analysis that types are fully specified. Try it if needed; it shouldn't be possible to produce this error message.
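I.e. something like this (sketch):

// The op verifier and the fully-specified-types precondition already
// guarantee equal ranks here; document the invariant with an assertion.
assert(srcShape.size() == targetShape.size() &&
       "rank mismatch should have been rejected by the PermuteOp verifier");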

Comment on lines 2006 to 2008
// If the target or source mapping is not found, we cannot propagate the
// index expression.
if (srcMappingIt == symbolToMapping.end()) {
Contributor

Can this happen in IR that passed the verifier and satisfies the fully-specified-types normal form? If not, this should be an assertion. Same below.

If it can, maybe we should support some sort of partial propagation, covering only the symbols that are present, in the hope that the other symbols show up later as we keep propagating.

IndexExprsLatticeStorage permuted = permuteIndexExprsStrides(
operandExprs[0], srcShape, targetShape, getContext(), emitError);

permuted = permuted.keepOnlySymbols(resultType.getShape());
Contributor

Is this needed? AFAIU, we should have strictly the same symbols before and after.

Contributor Author

You are right, this is already handled.

IndexExprsLatticeStorage permuted = permuteIndexExprsStrides(
resultExprs[0], resultShape, srcShape, getContext(), emitError);

permuted = permuted.keepOnlySymbols(srcShape);
Contributor

Same as above

Contributor Author

removed

Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>


Development

Successfully merging this pull request may close these issues.

[water] Implement wave.permute
