Commit b8f742b
authored
[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op (#4278)
This PR will fix the following issue:
[Add lowering of torch.aten.pixel_unshuffle op to linalg
dialect](#4260)
This code snippet can reproduce the issue:
```
func.func @pixel_unshuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
%int2 = torch.constant.int 2
%0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
return %0 : !torch.vtensor<[1,32,2,2],f32>
}
```
The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
and PyTorch implementation could be found in
main/aten/src/ATen/native/PixelShuffle.cpp:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp
With code changes, torch.aten.pixel_unshuffle will be lowered to the
following:
```
module {
func.func @main(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%0 = torch.prim.ListConstruct %int0, %int1, %int3, %int5, %int2, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prims.split_dim %arg0, %int2, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
%2 = torch.prims.split_dim %1, %int4, %int2 : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
%3 = torch.aten.permute %2, %0 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
%4 = torch.prims.collapse %3, %int2, %int3 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
%5 = torch.prims.collapse %4, %int1, %int2 : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
return %5 : !torch.vtensor<[1,32,2,2],f32>
}
}
```1 parent 4f572c5 commit b8f742b
File tree
10 files changed
+428
-57
lines changed- include/torch-mlir/Dialect/Torch/IR
- lib/Dialect/Torch
- Transforms
- Utils
- projects/pt1
- e2e_testing
- python
- torch_mlir_e2e_test/test_suite
- torch_mlir/jit_ir_importer/build_tools
- test/Dialect/Torch
10 files changed
+428
-57
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8668 | 8668 | | |
8669 | 8669 | | |
8670 | 8670 | | |
| 8671 | + | |
| 8672 | + | |
| 8673 | + | |
| 8674 | + | |
| 8675 | + | |
| 8676 | + | |
| 8677 | + | |
| 8678 | + | |
| 8679 | + | |
| 8680 | + | |
| 8681 | + | |
| 8682 | + | |
| 8683 | + | |
| 8684 | + | |
| 8685 | + | |
| 8686 | + | |
| 8687 | + | |
| 8688 | + | |
| 8689 | + | |
| 8690 | + | |
| 8691 | + | |
| 8692 | + | |
| 8693 | + | |
| 8694 | + | |
8671 | 8695 | | |
8672 | 8696 | | |
8673 | 8697 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
7613 | 7613 | | |
7614 | 7614 | | |
7615 | 7615 | | |
| 7616 | + | |
| 7617 | + | |
| 7618 | + | |
| 7619 | + | |
| 7620 | + | |
| 7621 | + | |
| 7622 | + | |
| 7623 | + | |
| 7624 | + | |
| 7625 | + | |
| 7626 | + | |
| 7627 | + | |
| 7628 | + | |
| 7629 | + | |
| 7630 | + | |
| 7631 | + | |
| 7632 | + | |
| 7633 | + | |
| 7634 | + | |
| 7635 | + | |
| 7636 | + | |
| 7637 | + | |
| 7638 | + | |
| 7639 | + | |
| 7640 | + | |
| 7641 | + | |
| 7642 | + | |
| 7643 | + | |
| 7644 | + | |
| 7645 | + | |
| 7646 | + | |
| 7647 | + | |
| 7648 | + | |
| 7649 | + | |
| 7650 | + | |
| 7651 | + | |
| 7652 | + | |
| 7653 | + | |
| 7654 | + | |
| 7655 | + | |
| 7656 | + | |
| 7657 | + | |
| 7658 | + | |
| 7659 | + | |
| 7660 | + | |
| 7661 | + | |
| 7662 | + | |
| 7663 | + | |
| 7664 | + | |
| 7665 | + | |
7616 | 7666 | | |
7617 | 7667 | | |
7618 | 7668 | | |
| |||
12411 | 12461 | | |
12412 | 12462 | | |
12413 | 12463 | | |
| 12464 | + | |
| 12465 | + | |
| 12466 | + | |
| 12467 | + | |
12414 | 12468 | | |
12415 | 12469 | | |
12416 | 12470 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3536 | 3536 | | |
3537 | 3537 | | |
3538 | 3538 | | |
3539 | | - | |
3540 | | - | |
3541 | | - | |
3542 | | - | |
3543 | | - | |
3544 | | - | |
3545 | | - | |
3546 | | - | |
3547 | | - | |
3548 | | - | |
3549 | | - | |
3550 | | - | |
3551 | | - | |
3552 | | - | |
3553 | | - | |
3554 | | - | |
3555 | | - | |
3556 | | - | |
3557 | | - | |
3558 | | - | |
3559 | | - | |
3560 | | - | |
3561 | | - | |
3562 | | - | |
3563 | 3539 | | |
3564 | 3540 | | |
3565 | 3541 | | |
| |||
3609 | 3585 | | |
3610 | 3586 | | |
3611 | 3587 | | |
3612 | | - | |
3613 | | - | |
3614 | | - | |
3615 | | - | |
3616 | | - | |
3617 | | - | |
3618 | | - | |
3619 | | - | |
3620 | | - | |
3621 | | - | |
3622 | | - | |
3623 | | - | |
| 3588 | + | |
| 3589 | + | |
| 3590 | + | |
3624 | 3591 | | |
3625 | 3592 | | |
3626 | 3593 | | |
| |||
3678 | 3645 | | |
3679 | 3646 | | |
3680 | 3647 | | |
3681 | | - | |
| 3648 | + | |
| 3649 | + | |
| 3650 | + | |
3682 | 3651 | | |
3683 | 3652 | | |
3684 | 3653 | | |
3685 | 3654 | | |
3686 | 3655 | | |
3687 | | - | |
| 3656 | + | |
3688 | 3657 | | |
3689 | 3658 | | |
3690 | 3659 | | |
3691 | 3660 | | |
3692 | | - | |
3693 | | - | |
| 3661 | + | |
| 3662 | + | |
3694 | 3663 | | |
3695 | 3664 | | |
3696 | 3665 | | |
3697 | | - | |
| 3666 | + | |
| 3667 | + | |
3698 | 3668 | | |
3699 | 3669 | | |
3700 | 3670 | | |
| |||
3709 | 3679 | | |
3710 | 3680 | | |
3711 | 3681 | | |
| 3682 | + | |
| 3683 | + | |
| 3684 | + | |
| 3685 | + | |
| 3686 | + | |
| 3687 | + | |
| 3688 | + | |
| 3689 | + | |
| 3690 | + | |
| 3691 | + | |
| 3692 | + | |
| 3693 | + | |
| 3694 | + | |
| 3695 | + | |
| 3696 | + | |
| 3697 | + | |
| 3698 | + | |
| 3699 | + | |
| 3700 | + | |
| 3701 | + | |
| 3702 | + | |
| 3703 | + | |
| 3704 | + | |
| 3705 | + | |
| 3706 | + | |
| 3707 | + | |
| 3708 | + | |
| 3709 | + | |
| 3710 | + | |
| 3711 | + | |
| 3712 | + | |
| 3713 | + | |
| 3714 | + | |
| 3715 | + | |
| 3716 | + | |
| 3717 | + | |
| 3718 | + | |
| 3719 | + | |
| 3720 | + | |
| 3721 | + | |
| 3722 | + | |
| 3723 | + | |
| 3724 | + | |
| 3725 | + | |
| 3726 | + | |
| 3727 | + | |
| 3728 | + | |
| 3729 | + | |
| 3730 | + | |
| 3731 | + | |
| 3732 | + | |
| 3733 | + | |
| 3734 | + | |
| 3735 | + | |
| 3736 | + | |
| 3737 | + | |
| 3738 | + | |
| 3739 | + | |
| 3740 | + | |
| 3741 | + | |
| 3742 | + | |
| 3743 | + | |
| 3744 | + | |
| 3745 | + | |
| 3746 | + | |
| 3747 | + | |
| 3748 | + | |
| 3749 | + | |
| 3750 | + | |
| 3751 | + | |
| 3752 | + | |
| 3753 | + | |
| 3754 | + | |
| 3755 | + | |
| 3756 | + | |
| 3757 | + | |
| 3758 | + | |
| 3759 | + | |
| 3760 | + | |
| 3761 | + | |
| 3762 | + | |
| 3763 | + | |
| 3764 | + | |
| 3765 | + | |
| 3766 | + | |
| 3767 | + | |
| 3768 | + | |
| 3769 | + | |
| 3770 | + | |
| 3771 | + | |
| 3772 | + | |
| 3773 | + | |
| 3774 | + | |
| 3775 | + | |
| 3776 | + | |
| 3777 | + | |
| 3778 | + | |
| 3779 | + | |
| 3780 | + | |
| 3781 | + | |
| 3782 | + | |
| 3783 | + | |
| 3784 | + | |
| 3785 | + | |
| 3786 | + | |
| 3787 | + | |
| 3788 | + | |
| 3789 | + | |
| 3790 | + | |
| 3791 | + | |
| 3792 | + | |
| 3793 | + | |
| 3794 | + | |
| 3795 | + | |
| 3796 | + | |
| 3797 | + | |
| 3798 | + | |
| 3799 | + | |
| 3800 | + | |
| 3801 | + | |
| 3802 | + | |
| 3803 | + | |
| 3804 | + | |
| 3805 | + | |
| 3806 | + | |
| 3807 | + | |
| 3808 | + | |
| 3809 | + | |
| 3810 | + | |
| 3811 | + | |
| 3812 | + | |
| 3813 | + | |
| 3814 | + | |
| 3815 | + | |
| 3816 | + | |
| 3817 | + | |
3712 | 3818 | | |
3713 | 3819 | | |
3714 | 3820 | | |
| |||
3763 | 3869 | | |
3764 | 3870 | | |
3765 | 3871 | | |
3766 | | - | |
3767 | | - | |
3768 | | - | |
3769 | | - | |
3770 | | - | |
3771 | | - | |
3772 | | - | |
3773 | | - | |
3774 | | - | |
3775 | 3872 | | |
3776 | 3873 | | |
3777 | 3874 | | |
3778 | | - | |
| 3875 | + | |
3779 | 3876 | | |
3780 | 3877 | | |
3781 | 3878 | | |
3782 | | - | |
| 3879 | + | |
3783 | 3880 | | |
3784 | 3881 | | |
3785 | 3882 | | |
| |||
3832 | 3929 | | |
3833 | 3930 | | |
3834 | 3931 | | |
3835 | | - | |
3836 | | - | |
| 3932 | + | |
| 3933 | + | |
3837 | 3934 | | |
3838 | 3935 | | |
3839 | 3936 | | |
3840 | 3937 | | |
3841 | | - | |
3842 | | - | |
| 3938 | + | |
| 3939 | + | |
3843 | 3940 | | |
3844 | 3941 | | |
3845 | 3942 | | |
| |||
12909 | 13006 | | |
12910 | 13007 | | |
12911 | 13008 | | |
| 13009 | + | |
12912 | 13010 | | |
12913 | 13011 | | |
12914 | 13012 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
421 | 421 | | |
422 | 422 | | |
423 | 423 | | |
| 424 | + | |
424 | 425 | | |
425 | 426 | | |
426 | 427 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
327 | 327 | | |
328 | 328 | | |
329 | 329 | | |
330 | | - | |
| 330 | + | |
| 331 | + | |
331 | 332 | | |
332 | 333 | | |
333 | 334 | | |
| |||
0 commit comments