
Conversation


@alaa-ali alaa-ali commented Jul 18, 2025

This PR fixes the following issue:
Add lowering of the torch.aten.pixel_unshuffle op to the linalg dialect

The following code snippet reproduces the issue:

func.func @pixel_unshuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int2 = torch.constant.int 2
  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
  return %0 : !torch.vtensor<[1,32,2,2],f32>
}

The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
and the PyTorch implementation can be found in aten/src/ATen/native/PixelShuffle.cpp:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp
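For reference, the split/permute/collapse decomposition used here can be sketched in NumPy for a rank-4 NCHW input. This is an illustrative port of the pixel_unshuffle semantics, not the PR's code; the function name is made up for this sketch:

```python
import numpy as np

def pixel_unshuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Reference semantics of pixel_unshuffle for an NCHW input,
    mirroring the split_dim / permute / collapse decomposition."""
    n, c, h, w = x.shape
    assert h % r == 0 and w % r == 0, "spatial dims must be divisible by r"
    # split_dim twice: (N, C, H/r, r, W/r, r)
    x = x.reshape(n, c, h // r, r, w // r, r)
    # permute [0, 1, 3, 5, 2, 4]: (N, C, r, r, H/r, W/r)
    x = x.transpose(0, 1, 3, 5, 2, 4)
    # collapse twice: (N, C*r*r, H/r, W/r)
    return x.reshape(n, c * r * r, h // r, w // r)

x = np.arange(1 * 8 * 4 * 4, dtype=np.float32).reshape(1, 8, 4, 4)
print(pixel_unshuffle(x, 2).shape)  # (1, 32, 2, 2)
```

Each step corresponds to one op in the lowered IR: the two reshapes that add dimensions map to torch.prims.split_dim, the transpose to torch.aten.permute, and the final reshape to the two torch.prims.collapse ops.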

With these changes, torch.aten.pixel_unshuffle is lowered to the following:

module {
  func.func @main(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int2 = torch.constant.int 2
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int3 = torch.constant.int 3
    %int4 = torch.constant.int 4
    %int5 = torch.constant.int 5
    %0 = torch.prim.ListConstruct %int0, %int1, %int3, %int5, %int2, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prims.split_dim %arg0, %int2, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
    %2 = torch.prims.split_dim %1, %int4, %int2 : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
    %3 = torch.aten.permute %2, %0 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
    %4 = torch.prims.collapse %3, %int2, %int3 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
    %5 = torch.prims.collapse %4, %int1, %int2 : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
    return %5 : !torch.vtensor<[1,32,2,2],f32>
  }
}

@ivangarcia44 ivangarcia44 left a comment


In general it looks good to me. My main feedback is:

  1. Can this be generalized for ND if the pixel shuffle operator allows ND?
  2. What is the expected batch dimension behavior? Can this be tested?

Thanks!

@ivangarcia44
Contributor

Are there tests for batch dimension omission and for more than one batch dimension? Thanks

@alaa-ali
Contributor Author

alaa-ali commented Aug 1, 2025

Are there tests for batch dimension omission and for more than one batch dimension? Thanks

This has been captured in e2e tests. Thanks for your feedback.

@ivangarcia44
Contributor

LGTM

This PR has tests for rank-3, rank-4, and rank-5 inputs, which cover the cases of no batch dimension, one batch dimension, and two batch dimensions. All my concerns are addressed in the PR. Looks good to me.
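The batch-dimension behavior being tested can be illustrated with a NumPy sketch that treats all leading dimensions before the trailing (C, H, W) as batch dims, matching the PyTorch contract for pixel_unshuffle. The helper name and this ND generalization are illustrative, not taken from the PR:

```python
import numpy as np

def pixel_unshuffle_nd(x: np.ndarray, r: int) -> np.ndarray:
    """Illustrative ND reference: everything before the last three
    dimensions (C, H, W) is treated as a batch dimension."""
    *batch, c, h, w = x.shape
    assert h % r == 0 and w % r == 0, "spatial dims must be divisible by r"
    # Split H and W by the downscale factor r.
    x = x.reshape(*batch, c, h // r, r, w // r, r)
    nb = len(batch)
    # Move the two factor dims next to the channel dim:
    # (batch..., C, r, r, H/r, W/r)
    axes = list(range(nb)) + [nb, nb + 2, nb + 4, nb + 1, nb + 3]
    x = x.transpose(axes)
    # Fold the factors into the channel dim.
    return x.reshape(*batch, c * r * r, h // r, w // r)

print(pixel_unshuffle_nd(np.zeros((8, 4, 4)), 2).shape)        # (32, 2, 2)
print(pixel_unshuffle_nd(np.zeros((2, 3, 8, 4, 4)), 2).shape)  # (2, 3, 32, 2, 2)
```

A rank-3 input exercises the no-batch case, rank 4 the single-batch case, and rank 5 the two-batch case, as covered by the e2e tests.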

@alaa-ali
Contributor Author

alaa-ali commented Aug 4, 2025

Hi everyone, a kind reminder to provide feedback, please. This PR adds support for the torch.aten.pixel_unshuffle op.
Thank you

@rsuderman @zjgarvey @penguin-wwy @newling @sahas3 @ramiro050 @qedawkins @vivekkhandelwal1
