[water] add wave.permute op #759
Conversation
Force-pushed 474d06e to 9dd20e3
Force-pushed 9dd20e3 to dda879e
Force-pushed dda879e to bb65e50
martin-luecke left a comment:
wave.permute uses the new CompatibleOperandsAndResultsIgnoreShapeOpTrait as the verifier for its input and return values, which ignores the shape entirely.
But for a permute the ranks should match, and target_shape should be a permutation of the input symbols. I think we should enforce this in the verifier; otherwise we can end up with mismatched ranks or symbol sets that later trip asserts (e.g., in zip_equal) or mis-infer.
Also, I think I found a couple of propagation issues, flagged inline.
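A minimal sketch of what that verifier check could look like (hand-written here, not from the patch; accessor names like getValue()/getTargetShape() and the exact attribute types are assumptions based on the op definition quoted further down):

static LogicalResult verifyIsPermutation(PermuteOp op) {
  // Assumed accessors: getValue() is the $value operand, getTargetShape()
  // yields the target symbols as an array.
  auto inputType = llvm::cast<WaveTensorType>(op.getValue().getType());
  ArrayRef<WaveSymbolAttr> inputShape = inputType.getShape();
  ArrayRef<WaveSymbolAttr> targetShape = op.getTargetShape();
  if (inputShape.size() != targetShape.size())
    return op.emitOpError() << "target_shape rank (" << targetShape.size()
                            << ") does not match input rank ("
                            << inputShape.size() << ")";
  // Multiset comparison: with equal ranks, no count ever going negative
  // means target_shape is exactly a permutation of the input symbols.
  llvm::DenseMap<mlir::Attribute, int> counts;
  for (WaveSymbolAttr sym : inputShape)
    ++counts[sym];
  for (WaveSymbolAttr sym : targetShape)
    if (--counts[sym] < 0)
      return op.emitOpError()
             << "target_shape is not a permutation of the input shape";
  return success();
}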
| "iterate": IterateOp, | ||
| "output": YieldOp, | ||
| "write": WriteOp, | ||
| "permute": PermuteOp, |
When adding it here, please also add a permute op to one of the pywave → wave dialect tests, or create a new test for it, so we know it actually works.
added tests
resultType = WaveTensorType::get(
    getContext(), targetShape, /*fully_specified=*/true,
    resultType ? resultType.getElementType() : inputType.getElementType(),
    resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());
I think this ternary can lose information when resultType.getAddressSpace() == wave::WaveAddressSpace::Unspecified but inputType actually has a specified address space.
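A possible fix, sketched under the assumption that getAddressSpace() returns a wave::WaveAddressSpace enum (as the comparison above suggests) and that Unspecified means "no information yet":

// Prefer the side that actually specifies an address space, so a known
// input address space is not overwritten by an unspecified result one.
wave::WaveAddressSpace addressSpace =
    (resultType &&
     resultType.getAddressSpace() != wave::WaveAddressSpace::Unspecified)
        ? resultType.getAddressSpace()
        : inputType.getAddressSpace();
resultType = WaveTensorType::get(
    getContext(), targetShape, /*fully_specified=*/true,
    resultType ? resultType.getElementType() : inputType.getElementType(),
    addressSpace);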
Also, propagateForward only checks that resultType matches target_shape (when the result is already fully specified). It never validates that target_shape is a permutation of the input shape. That means you can end up setting a fully specified result whose rank/symbols don't correspond to the input, and the index-expr code later assumes equal sizes. Let's add an explicit input-to-target permutation check (or enforce it in the op verifier).
This definitely should be in the op verifier, but we may also want another diagnostic here. It may be better for UX to say "inference resulted in a shape conflict" than just "shape conflict" on shape-less input DSL.
compilation. At lowering time, the operation is a pass-through since the
actual data layout in registers remains unchanged - only the interpretation
of which dimension each element belongs to changes.
I have always found this concept of pass-through permutation difficult to grasp; can this be illustrated with an example?
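For instance (a hand-written illustration, not from the patch; syntax follows the .mlir tests quoted later): the permute only relabels which symbol each dimension of the register tile corresponds to, while the elements each thread holds stay exactly where they are.

%t = ...             : !wave.tensor<[@M, @N] of f32, <register>>
%p = wave.permute %t : !wave.tensor<[@M, @N] of f32, <register>>
                    to !wave.tensor<[@N, @M] of f32, <register>>
// After lowering, %p is the very same register value as %t; only the
// index expressions attached to the result swap the strides of @M and @N.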
}];
let arguments = !con((ins
    Arg<WaveTensorInRegister, "Value to permute">:$value,
    Arg<WaveSymbolArrayAttr, "Target dimension ordering">:$target_shape
Do we need this attribute? When the result type is a tensor, it duplicates its shape. Is there a situation where we need it when the result type was lowered to a vector?
Good idea. Less to verify. We shouldn't need the shape downstream.
// If result is already fully specified, verify it matches target_shape.
if (resultType && resultType.getFullySpecified()) {
  ArrayRef<WaveSymbolAttr> resultShape = resultType.getShape();
  if (resultShape.size() != targetShape.size()) {
    errs << "result shape rank (" << resultShape.size()
         << ") does not match target_shape rank (" << targetShape.size()
         << ")";
    return failure();
  }
  for (auto [i, expected, actual] :
       llvm::enumerate(targetShape, resultShape)) {
    if (expected != actual) {
      errs << "result shape dimension #" << i << " (" << actual
           << ") does not match target_shape (" << expected << ")";
      return failure();
    }
  }
We have a helper function for this, checkPropagateShapeConflict.
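The exact signature of checkPropagateShapeConflict isn't visible in this diff, so purely as an illustration (hypothetical call shape), the hand-rolled loop above could collapse to something like:

// Hypothetical: assumes the helper compares the two symbol arrays and
// streams a diagnostic into errs on mismatch.
if (resultType && resultType.getFullySpecified()) {
  if (failed(checkPropagateShapeConflict(targetShape, resultType.getShape(),
                                         errs)))
    return failure();
}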
    resultType ? resultType.getElementType() : inputType.getElementType(),
    resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());

return ChangeResult::Change;
We need to check if the result type actually changes. Otherwise the analysis may never converge since we always indicate change. Does it?
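A sketch of the usual dataflow-analysis pattern, assuming resultType still holds the previous lattice value at this point:

// Recompute the type, but only report Change when it actually differs;
// otherwise the fixpoint iteration can keep re-enqueueing this op forever.
WaveTensorType newResultType = WaveTensorType::get(
    getContext(), targetShape, /*fully_specified=*/true,
    resultType ? resultType.getElementType() : inputType.getElementType(),
    resultType ? resultType.getAddressSpace() : inputType.getAddressSpace());
if (newResultType == resultType)
  return ChangeResult::NoChange;
resultType = newResultType;
return ChangeResult::Change;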
// Verify result shape matches target_shape.
if (resultShape.size() != targetShape.size()) {
  errs << "result shape rank (" << resultShape.size()
       << ") does not match target_shape rank (" << targetShape.size() << ")";
  return failure();
}
for (auto [i, expected, actual] : llvm::enumerate(targetShape, resultShape)) {
  if (expected != actual) {
    errs << "result shape dimension #" << i << " (" << actual
         << ") does not match target_shape (" << expected << ")";
    return failure();
  }
}
Same as above, we should have a helper doing this.
Force-pushed f0bcf95 to 47200f9
Force-pushed 47200f9 to b96f4d3
martin-luecke left a comment:
With the changes this LGTM, with a few nits still to address.
resultShapeSet.insert_range(resultType.getShape());

for (auto inputDim : inputType.getShape()) {
  auto [_, inserted] = resultShapeSet.insert(inputDim);
Why do you not use contains() to check if the inputDim is in the set? Seems more idiomatic, e.g.:

if (!resultShapeSet.contains(inputDim)) {
good point
if (inputType.getFullySpecified()) {
  std::string errorMessage;
  llvm::raw_string_ostream errs(errorMessage);
  if (failed(validatePermutationInput(inputType, resultType, errs))) {
    return emitOpError() << errorMessage;
  }
}
If you pass the op into validatePermutationInput as well, you wouldn't need the string here and could emit the error directly from inside it.
I think the MLIR-idiomatic approach to passing error messages is to construct an InFlightDiagnostic from the op and then compose it within validatePermutationInput. However, this would require calling abandon() on it in case there was no error. I think the first approach is cleaner.
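A sketch of the first suggestion (the extra Operation * parameter is hypothetical):

// Emit directly from the validation helper instead of buffering a string.
static LogicalResult validatePermutationInput(Operation *op,
                                              WaveTensorType inputType,
                                              WaveTensorType resultType) {
  if (inputType.getShape().size() != resultType.getShape().size())
    return op->emitOpError() << "result shape rank does not match input rank";
  // ... remaining permutation checks, all emitting via op->emitOpError() ...
  return success();
}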
This is due to how propagateForward and propagateBackward emit errors: if we emit directly here, it will mess up the order of the messages. Might be something to fix later, e.g. by passing an InFlightDiagnostic and adding notes.
// CHECK: !wave.tensor<[@B, @M, @N] of f32, <register>> to !wave.tensor<[@M, @N, @B] of f32, <register>>
wave.permute %a : !wave.tensor<[@B, @M, @N] of f32, <register>> to !wave.tensor<[@M, @N, @B] of f32, <register>>
return
}

// CHECK-LABEL: @propagate_permute_2d
func.func @propagate_permute_2d(%a: !wave.tensor<[@M, @N] of f16, <register>>) {
// CHECK: !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
wave.permute %a : !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
return
These CHECK lines only check parsing of the op; I don't see anything regarding type inference.
Could these just be removed in favor of the tests in ops.mlir, or do they test anything in addition?
I think we can test the following:

%0 = wave.negate %arg0 : @A, @B, @C -> any
wave.permute %0 any to @M, @N, @K

This will not fail the verifier, but the type inference should discover and report the conflict.
Yes, this does not really propagate anything anymore, but it's good to keep to make sure it does not fail. I'll also add the negative case.
ftynse left a comment:
Nice, LGTM % nits
// PermuteOp
//-----------------------------------------------------------------------------

static LogicalResult validatePermutationInput(WaveTensorType inputType,
Nit: please document functions.
added docs
// If result / input is a vector (post-lowering phase), skip wave tensor
// checks.
Nit: we want to verify element types for vectors as well. I added a helper function recently.
}

// Result type is already specified, propagate it.
return detail::propagateShapeInformation(resultType, resultType,
Does this do anything useful? It looks like this will always succeed and the resultType should already be initialized.
removed
if (srcShape.size() != targetShape.size()) {
  emitError() << "source shape rank (" << srcShape.size()
              << ") does not match target shape rank (" << targetShape.size()
              << ")";
  return IndexExprsLatticeStorage::top();
}
I think this should be an assertion. We check in the verifier that shapes have equal rank when present, and we also have a normal-form precondition for this entire analysis that types are fully specified. Try it if needed; it shouldn't be possible to produce this error message.
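I.e., roughly:

assert(srcShape.size() == targetShape.size() &&
       "rank mismatch should have been rejected by the verifier / normal form");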
// If the target or source mapping is not found, we cannot propagate the
// index expression.
if (srcMappingIt == symbolToMapping.end()) {
Can this happen in IR that passed the verifier and satisfies the full-types normal form? If not, this should be an assertion. Same below.
If it can, maybe we should support some sort of partial propagation, only for symbols that are present, in hopes that the other symbols show up later as we keep propagating.
IndexExprsLatticeStorage permuted = permuteIndexExprsStrides(
    operandExprs[0], srcShape, targetShape, getContext(), emitError);

permuted = permuted.keepOnlySymbols(resultType.getShape());
Is this needed? AFAIU, we should have strictly the same symbols before and after.
You are right, this is already handled.
IndexExprsLatticeStorage permuted = permuteIndexExprsStrides(
    resultExprs[0], resultShape, srcShape, getContext(), emitError);

permuted = permuted.keepOnlySymbols(srcShape);
Same as above
removed
Add a wave.permute op that permutes dimensions of a WaveTensorInRegister. Lowers to a no-op, but modifies propagation of index expressions by permuting the strides according to target_shape.