-
Notifications
You must be signed in to change notification settings - Fork 630
[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op #4278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general it looks good to me. My main feedback is:
- Can this be generalized for ND if the pixel shuffle operator allows ND?
- What is the expected batch dimension behavior? Can this be tested?
Thanks!
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
Show resolved
Hide resolved
Are there tests for batch batch dimension omission and more than one batch dimensions? Thanks |
This has been captured in e2e tests. Thanks for your feedback. |
LGTM This PR has tests for rank 3, 4, and 5 which cover the case of no batch dimension, 1 batch dimension and 2 batch dimensions. All my concerns are addressed in the PR. Looks good to me. |
Hi everyone, a kind reminder to provide feedback, please. This PR adds support of torch.aten.pixel_unshuffle op. @rsuderman @zjgarvey @penguin-wwy @newling @sahas3 @ramiro050 @qedawkins @vivekkhandelwal1 |
5f09b62
to
82215f3
Compare
This PR will fix the following issue:
Add lowering of torch.aten.pixel_unshuffle op to linalg dialect
This code snippet can reproduce the issue:
The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
and PyTorch implementation could be found in main/aten/src/ATen/native/PixelShuffle.cpp:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp
With code changes, torch.aten.pixel_unshuffle will be lowered to the following: