@@ -3710,6 +3710,177 @@ class DecomposeAtenPixelShuffleOp
3710
3710
};
3711
3711
} // namespace
3712
3712
3713
+ // Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
3714
+ // prims.collapse operations.
3715
+ //
3716
+ // This is the exact inverse of aten.pixel_shuffle.
3717
+ //
3718
+ // If input is a tensor of shape
3719
+ // (*leading_dims, C, H*r, W*r),
3720
+ //
3721
+ // where leading_dims is of size N, then
3722
+ // X = pixel_unshuffle(input, downscale_factor)
3723
+ //
3724
+ // gets replaced with
3725
+ // X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r)
3726
+ // X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r)
3727
+ // X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
3728
+ // # shape (*leading_dims, C, r, r, H, W)
3729
+ // X = X.collapse(...) # shape (*leading_dims, C, r*r, H, W)
3730
+ // X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
3731
+ //
3732
+ // 'r' above is referred to as the 'downscale factor' or just 'factor' below.
3733
+ namespace {
3734
+ class DecomposeAtenPixelUnshuffleOp
3735
+ : public OpRewritePattern<AtenPixelUnshuffleOp> {
3736
+ public:
3737
+ using OpRewritePattern::OpRewritePattern;
3738
+ LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
3739
+ PatternRewriter &rewriter) const override {
3740
+ // Fails to match when the input sizes are unavailable or the rank is below 3; otherwise replaces 'op' with split_dim -> split_dim -> permute -> collapse -> collapse and returns success.
3741
+ Location loc = op.getLoc();
3742
+ Value inValue = op.getSelf();
3743
+ auto inType = cast<BaseTensorType>(inValue.getType());
3744
+ auto maybeSizes = inType.getOptionalSizes();
3745
+ if (!maybeSizes) {
3746
+ return rewriter.notifyMatchFailure(
3747
+ op, "Expected input tensor to have known rank.");
3748
+ }
3749
+ auto inShape = maybeSizes.value();
3750
+ auto inRank = inShape.size();
3751
+
3752
+ // The input tensor must have at least 3 dimensions: (1) the channel
3753
+ // dimension, which gets bigger by 'factor*factor', (2) the H dimension,
3754
+ // which gets smaller by 'factor', and (3) the W dimension, which gets
3755
+ // smaller by 'factor'. The total number of dimensions is 3 + N, where N is
3756
+ // the number of leading dimensions, and N >= 0, so the rank is at least 3.
3757
+ if (inRank < 3)
3758
+ return rewriter.notifyMatchFailure(
3759
+ op, "Expected input tensor to have rank greater than 2.");
3760
+
3761
+ const auto inOptionalDType = inType.getOptionalDtype();
3762
+ // Builds the ValueTensorType of an intermediate result from its shape Values, reusing the input's optional dtype.
3763
+ auto getTypeFromShape = [inOptionalDType](auto &&vals) {
3764
+ // Get a vector of integers from a vector of Values: constant ints keep their value, anything else becomes kUnknownSize.
3765
+ auto getIntShape = [](auto &&vals) {
3766
+ SmallVector<int64_t> shape;
3767
+ shape.reserve(vals.size());
3768
+ for (auto v : vals) {
3769
+ int64_t cst_val;
3770
+ if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3771
+ shape.push_back(cst_val);
3772
+ } else {
3773
+ shape.push_back(kUnknownSize);
3774
+ }
3775
+ }
3776
+ return shape;
3777
+ };
3778
+
3779
+ const auto intShape = getIntShape(vals);
3780
+ return ValueTensorType::get(vals[0].getContext(),
3781
+ llvm::ArrayRef(intShape), inOptionalDType);
3782
+ };
3783
+
3784
+ auto nLeadingDims = inRank - 3;
3785
+
3786
+ // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
3787
+ // of 'create': if the dimension size is known, then the AtenSizeIntOp is
3788
+ // folded to a ConstantOp.
3789
+ auto getDimSize = [&](uint64_t i) -> Value {
3790
+ Value dim =
3791
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3792
+ return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3793
+ };
3794
+
3795
+ auto inC = getDimSize(inRank - 3);
3796
+ auto inH = getDimSize(inRank - 2);
3797
+ auto inW = getDimSize(inRank - 1);
3798
+
3799
+ auto factor = op.getDownscaleFactor();
3800
+
3801
+ Value factorSquared =
3802
+ rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
3803
+
3804
+ Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
3805
+
3806
+ Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
3807
+ Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
3808
+ // Integer constants 0 .. inRank + 1, reused below as dimension indices (the widest intermediate has rank inRank + 2).
3809
+ SmallVector<Value> dimensionConstants;
3810
+ dimensionConstants.reserve(inRank + 2);
3811
+ for (unsigned i = 0; i < inRank + 2; ++i) {
3812
+ dimensionConstants.push_back(
3813
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3814
+ }
3815
+
3816
+ SmallVector<Value> leadingDims;
3817
+ leadingDims.reserve(nLeadingDims);
3818
+ for (unsigned i = 0; i < nLeadingDims; ++i) {
3819
+ Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3820
+ loc, inValue, dimensionConstants[i]);
3821
+ leadingDims.push_back(leadingDimSize);
3822
+ }
3823
+ // Shapes of the successive intermediate tensors, expressed as size Values; each starts with the leading dims.
3824
+ SmallVector<Value> partiallyExpandedShape = leadingDims;
3825
+ partiallyExpandedShape.append({inC, outH, factor, inW});
3826
+
3827
+ SmallVector<Value> prePermuteShape = leadingDims;
3828
+ prePermuteShape.append({inC, outH, factor, outW, factor});
3829
+
3830
+ SmallVector<Value> postPermuteShape = leadingDims;
3831
+ postPermuteShape.append({inC, factor, factor, outH, outW});
3832
+
3833
+ SmallVector<Value> partiallyCollapsedShape = leadingDims;
3834
+ partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
3835
+
3836
+ SmallVector<Value> outShape = leadingDims;
3837
+ outShape.append({outC, outH, outW});
3838
+ // Leading dims keep their position; the trailing five dims (C, H, r, W, r) are reordered to (C, r, r, H, W).
3839
+ SmallVector<Value> permutation{dimensionConstants.begin(),
3840
+ dimensionConstants.begin() + nLeadingDims};
3841
+ SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
3842
+ for (uint64_t d : permutationTail) {
3843
+ permutation.push_back(dimensionConstants[nLeadingDims + d]);
3844
+ }
3845
+
3846
+ Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3847
+ loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3848
+ permutation);
3849
+
3850
+ // Split the H dimension: inH -> (outH, factor)
3851
+ auto partiallyExpanded =
3852
+ rewriter
3853
+ .create<PrimsSplitDimOp>(
3854
+ loc, getTypeFromShape(partiallyExpandedShape), inValue,
3855
+ dimensionConstants[nLeadingDims + 1], outH)
3856
+ .getResult();
3857
+
3858
+ // Split the W dimension: inW -> (outW, factor); its index moved up by one after the H split
3859
+ auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3860
+ loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
3861
+ dimensionConstants[nLeadingDims + 3], outW);
3862
+
3863
+ // Perform the permutation
3864
+ auto permuted =
3865
+ rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
3866
+ fullyExpanded, permuteDimsOrder);
3867
+
3868
+ // Collapse the two 'factor' dimensions into a single one of size factor*factor
3869
+ auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3870
+ loc, getTypeFromShape(partiallyCollapsedShape), permuted,
3871
+ dimensionConstants[nLeadingDims + 1],
3872
+ dimensionConstants[nLeadingDims + 2]);
3873
+
3874
+ // Collapse C with the factor*factor dimension, returning to the original rank
3875
+ rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
3876
+ op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims],
3877
+ dimensionConstants[nLeadingDims + 1]);
3878
+
3879
+ return success();
3880
+ }
3881
+ };
3882
+ } // namespace
3883
+
3713
3884
// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
3714
3885
// prims.collapse operations.
3715
3886
//
@@ -12859,6 +13030,7 @@ class DecomposeComplexOpsPass
12859
13030
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
12860
13031
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
12861
13032
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
13033
+ addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
12862
13034
addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
12863
13035
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
12864
13036
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
0 commit comments