[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op #4278
Conversation
In general it looks good to me. My main feedback is:
- Can this be generalized for ND if the pixel shuffle operator allows ND?
- What is the expected batch dimension behavior? Can this be tested?
Thanks!
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
Are there tests for batch dimension omission and for more than one batch dimension? Thanks
This has been captured in e2e tests. Thanks for your feedback.
LGTM. This PR has tests for rank 3, 4, and 5, which cover the cases of no batch dimension, one batch dimension, and two batch dimensions. All my concerns are addressed in the PR.
Hi everyone, a kind reminder to provide feedback, please. This PR adds support for the torch.aten.pixel_unshuffle op. @rsuderman @zjgarvey @penguin-wwy @newling @sahas3 @ramiro050 @qedawkins @vivekkhandelwal1
// (*leading_dims, C, H*r, W*r),
//
// where leading_dims is of size N, then
// X = pixel_unshuffle(input, downscale_factor)
Using `r` here instead of `downscale_factor` will be better for consistency. You can also move line 3731, which mentions that `r` is the downscale_factor, above this to help readability.
// X = X.collapse(...) # shape (*leading_dims, C, r*r, H, W)
// X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
Do we need two collapses -- isn't collapsing directly to `C*r*r` sufficient?
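For reference, here is a minimal eager-mode sketch of the decomposition the comment describes, with the final merge done as a single reshape. The 4D shape and the function name are illustrative assumptions, not the PR's actual code:

```python
import torch

def pixel_unshuffle_decomposed(x: torch.Tensor, r: int) -> torch.Tensor:
    # Illustrative 4D case (N, C, H*r, W*r); the PR's lowering also handles
    # other numbers of leading dims.
    n, c, hr, wr = x.shape
    h, w = hr // r, wr // r
    x = x.reshape(n, c, h, r, w, r)       # split H*r -> (H, r) and W*r -> (W, r)
    x = x.permute(0, 1, 3, 5, 2, 4)       # (N, C, r, r, H, W)
    return x.reshape(n, c * r * r, h, w)  # merge (C, r, r) -> C*r*r in one step

x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
assert torch.equal(pixel_unshuffle_decomposed(x, 2),
                   torch.nn.functional.pixel_unshuffle(x, 2))
```

Semantically, a single merge of the (C, r, r) dims suffices, as above; whether the lowering emits one collapse or two is an implementation detail of how the intermediate shape ops are constructed.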
auto getTypeFromShape = [inOptionalDType](auto &&vals) {
  // Get a vector of integers from a vector of Values.
  auto getIntShape = [](auto &&vals) {
    SmallVector<int64_t> shape;
    shape.reserve(vals.size());
    for (auto v : vals) {
      int64_t cst_val;
      if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
        shape.push_back(cst_val);
      } else {
        shape.push_back(kUnknownSize);
      }
    }
    return shape;
  };

  const auto intShape = getIntShape(vals);
  return ValueTensorType::get(vals[0].getContext(),
                              llvm::ArrayRef(intShape), inOptionalDType);
};
These methods were added as utilities in #4259. Once that is merged, can you update your code to reuse the utilities?
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
  Value dim =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
  return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};
This is shared with #4259 as well. It'll be good to move this into a utility method, and probably move all these utilities to lib/Conversion/Utils/Utils.cpp as the shared location, so they can be used elsewhere in the code base too.
SmallVector<Value> partiallyExpandedShape = leadingDims;
partiallyExpandedShape.append({inC, outH, factor, inW});
Can you move this before its use and also rename it to `heightSplitShape` for readability?
This PR fixes the following issue:
Add lowering of torch.aten.pixel_unshuffle op to linalg dialect
This code snippet can reproduce the issue:
The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
and the PyTorch reference implementation can be found in aten/src/ATen/native/PixelShuffle.cpp:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp
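For context, a small example of the op's documented behavior (the shapes are chosen only for illustration): pixel_unshuffle with downscale factor r maps (*, C, H*r, W*r) to (*, C*r*r, H, W).

```python
import torch

x = torch.randn(1, 3, 12, 12)
y = torch.nn.functional.pixel_unshuffle(x, downscale_factor=3)
print(y.shape)  # torch.Size([1, 27, 4, 4]): channels 3*3*3, spatial dims 12/3
```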
With these code changes, torch.aten.pixel_unshuffle will be lowered to the following: