@@ -3537,6 +3537,30 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
};
} // namespace

+ namespace { // Start of rearrangement ops utility functions
+ // Extracts shape as vector of int64_t from vector of Value
+ SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
+   SmallVector<int64_t> shape;
+   shape.reserve(vals.size());
+   for (Value v : vals) {
+     int64_t cst_val;
+     if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
+       shape.push_back(cst_val);
+     } else {
+       shape.push_back(kUnknownSize);
+     }
+   }
+   return shape;
+ }
+
+ // Converts a vector of Value (shape dimensions) into a ValueTensorType
+ ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
+   SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
+   return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
+                               inOptionalDType);
+ }
+ } // namespace
+
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
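For context, a quick sketch of what the two hoisted helpers produce; `dim0`, `dimDyn`, and `f32Type` are hypothetical stand-ins, not names from the patch:

    // dim0 folds to a torch.constant.int 2; dimDyn is only known at runtime.
    SmallVector<Value> dims = {dim0, dimDyn};
    getIntShapeFromValues(dims);     // yields {2, kUnknownSize}
    getTypeFromShape(dims, f32Type); // yields !torch.vtensor<[2,?],f32>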
@@ -3562,7 +3586,6 @@ class DecomposeAtenPixelShuffleOp
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
                                PatternRewriter &rewriter) const override {
-
    Location loc = op.getLoc();
    Value inValue = op.getSelf();
    auto inType = cast<BaseTensorType>(inValue.getType());
@@ -3585,27 +3608,6 @@ class DecomposeAtenPixelShuffleOp
    const auto inOptionalDType = inType.getOptionalDtype();

-   auto getTypeFromShape = [inOptionalDType](auto &&vals) {
-     // Get a vector of integers from a vector of Values.
-     auto getIntShape = [](auto &&vals) {
-       SmallVector<int64_t> shape;
-       shape.reserve(vals.size());
-       for (auto v : vals) {
-         int64_t cst_val;
-         if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
-           shape.push_back(cst_val);
-         } else {
-           shape.push_back(kUnknownSize);
-         }
-       }
-       return shape;
-     };
-
-     const auto intShape = getIntShape(vals);
-     return ValueTensorType::get(vals[0].getContext(),
-                                 llvm::ArrayRef(intShape), inOptionalDType);
-   };
-
    auto nLeadingDims = inRank - 3;

    // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
@@ -3677,24 +3679,24 @@ class DecomposeAtenPixelShuffleOp
    auto partiallyExpanded =
        rewriter
            .create<PrimsSplitDimOp>(
-               loc, getTypeFromShape(partiallyExpandedShape), inValue,
-               dimensionConstants[nLeadingDims], outC)
+               loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType),
+               inValue, dimensionConstants[nLeadingDims], outC)
            .getResult();

    // Split new dimension factorSquared -> (factor, factor)
    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
-       loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
-       dimensionConstants[nLeadingDims + 1], factor);
+       loc, getTypeFromShape(prePermuteShape, inOptionalDType),
+       partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);

    // Perform the permutation
-   auto permuted =
-       rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
-                                      fullyExpanded, permuteDimsOrder);
+   auto permuted = rewriter.create<AtenPermuteOp>(
+       loc, getTypeFromShape(postPermuteShape, inOptionalDType), fullyExpanded,
+       permuteDimsOrder);
    // Collapse final 2 dimensions
    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
-       loc, getTypeFromShape(partiallyCollapsedShape), permuted,
-       dimensionConstants[nLeadingDims + 3],
+       loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
+       permuted, dimensionConstants[nLeadingDims + 3],
        dimensionConstants[nLeadingDims + 4]);

    // Collapse back to original rank
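To make the shape bookkeeping above concrete, a worked example (illustrative only, following standard PixelShuffle semantics): input shape (1, 4, 2, 3) with upscale factor r = 2, so C = 1:

    split_dim: (1, 4, 2, 3)       -> (1, 1, 4, 2, 3)     // C*r*r -> (C, r*r)
    split_dim: (1, 1, 4, 2, 3)    -> (1, 1, 2, 2, 2, 3)  // r*r -> (r, r)
    permute:   (1, 1, 2, 2, 2, 3) -> (1, 1, 2, 2, 3, 2)  // to (N, C, H, r, W, r)
    collapse:  (1, 1, 2, 2, 3, 2) -> (1, 1, 2, 2, 6)     // (W, r) -> W*r
    collapse:  (1, 1, 2, 2, 6)    -> (1, 1, 4, 6)        // (H, r) -> H*r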
@@ -3708,6 +3710,147 @@ class DecomposeAtenPixelShuffleOp
};
} // namespace

+ // Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
+ // prims.collapse operations.
+ //
+ // If input is a tensor of shape
+ //     (N, g*C, H, W),
+ //
+ // then
+ //     X = channel_shuffle(input, groups)
+ //
+ // gets replaced with
+ //     X = input.split_dim(...)    # shape (N, g, C, *)
+ //     X = X.permute(0, 2, 1, ...) # shape (N, C, g, *)
+ //     X = X.collapse(...)         # shape (N, C*g, *)
+ //
+ // 'g' above is referred to as the number of 'groups'. N is the batch
+ // dimension, and can't be omitted. If the batch dimension is omitted in
+ // PyTorch's ChannelShuffle operator, the first spatial dimension is treated
+ // as the channel dimension. PyTorch errors out for the code below,
+ // indicating that 4 is not divisible by 3:
+ //     input_tensor = torch.arange(1, 37, dtype=torch.float32).view(3, 4, 3)
+ //     channel_shuffle_layer = nn.ChannelShuffle(groups=3)
+ //     output_tensor = channel_shuffle_layer(input_tensor)
+ //
+ // The decomposition is based on this specification:
+ //     https://pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
+ // and on the PyTorch implementation: aten/src/ATen/native/ChanelShuffle.cpp
+ // (yes, the filename is misspelled "Chanel" in upstream PyTorch)
+ //
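+ // For example, with groups=2 the channel axis [c0, c1, c2, c3] of an
+ // (N, 4, H, W) input is reordered to [c0, c2, c1, c3]:
+ //     (N, 4, H, W) -> split -> (N, 2, 2, H, W) -> permute ->
+ //     (N, 2, 2, H, W) -> collapse -> (N, 4, H, W)
+ //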
+ namespace {
+ class DecomposeAtenChannelShuffleOp
+     : public OpRewritePattern<AtenChannelShuffleOp> {
+ public:
+   using OpRewritePattern::OpRewritePattern;
+   LogicalResult matchAndRewrite(AtenChannelShuffleOp op,
+                                 PatternRewriter &rewriter) const override {
+     Location loc = op.getLoc();
+     Value inValue = op.getSelf();
+     auto inType = cast<BaseTensorType>(inValue.getType());
+     auto maybeSizes = inType.getOptionalSizes();
+     if (!maybeSizes) {
+       return rewriter.notifyMatchFailure(
+           op, "Expected input tensor to have known rank.");
+     }
+     auto inShape = maybeSizes.value();
+     auto inRank = inShape.size();
+
+     // The input tensor must have at least 3 dimensions: batch size,
+     // channel size, and at least one spatial dimension.
+     if (inRank < 3)
+       return rewriter.notifyMatchFailure(
+           op, "Expected input tensor to have rank greater than or equal to 3.");
+
+     auto numOfSpatialDims = inRank - 2;
+
+     // Get the size of the dimension 'i'. Note the use of 'createOrFold'
+     // instead of 'create': if the dimension size is known, then the
+     // AtenSizeIntOp is folded to a ConstantOp.
+     auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
+       Value dim =
+           rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+       return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
+     };
+
+     // The channel dimension is always the second dimension. PyTorch errors out
+     // if the batch dimension (first dimension) is not present. See comment at
+     // the top of this class for details.
+     auto inC = getDimSize(1);
+     SmallVector<Value> inSpatialDims;
+     inSpatialDims.reserve(numOfSpatialDims);
+     for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
+       inSpatialDims.push_back(getDimSize(i));
+     }
+
+     auto groups = op.getGroups();
+
+     // Temporary channel dimension size: tempC = inC / groups
+     // Assumes input has been validated: `inC % groups == 0`
+     // This is enforced by PyTorch's runtime and is required for correctness.
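+     // For example, inC = 6 with groups = 3 gives tempC = 2.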
+     Value tempC = rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, groups);
+
+     // Create constants for split/permute/collapse operations. Note that we
+     // need an extra constant for the channel dimension split.
+     SmallVector<Value> dimensionConstants;
+     dimensionConstants.reserve(inRank + 1);
+     for (unsigned i = 0; i < inRank + 1; ++i) {
+       dimensionConstants.push_back(
+           rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
+     }
+
+     Value batchDimSize = rewriter.createOrFold<AtenSizeIntOp>(
+         loc, inValue, dimensionConstants[0]);
+
+     SmallVector<Value> splitShape;
+     splitShape.reserve(inRank + 1);
+     splitShape.append({batchDimSize, groups, tempC});
+     splitShape.append(inSpatialDims); // Appends all spatial dimensions
+
+     SmallVector<Value> permuteShape;
+     permuteShape.reserve(inRank + 1);
+     permuteShape.append({batchDimSize, tempC, groups});
+     permuteShape.append(inSpatialDims); // Appends all spatial dimensions
+
+     // Permute (N, groups, tempC, *) -> (N, tempC, groups, *)
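+     // (For a 4-D NCHW input this is the order (0, 2, 1, 3, 4) over the
+     // five post-split dimensions.)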
+     SmallVector<Value> permutation{dimensionConstants[0],  // batch dimension
+                                    dimensionConstants[2],  // tempC
+                                    dimensionConstants[1]}; // groups
+     for (unsigned i = 3; i < inRank + 1; ++i) {
+       permutation.push_back(dimensionConstants[i]);
+     }
+
+     Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
+         loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+         permutation);
+
+     const auto inOptionalDType = inType.getOptionalDtype();
+
+     Value dimC = dimensionConstants[1];
+     Value dimG = dimensionConstants[2];
+
+     // Split input channel inC -> (groups, inC/groups)
+     auto expandedTensor =
+         rewriter
+             .create<PrimsSplitDimOp>(
+                 loc, getTypeFromShape(splitShape, inOptionalDType), inValue,
+                 dimC, tempC)
+             .getResult();
+
+     // Perform the permutation
+     auto permuted = rewriter.create<AtenPermuteOp>(
+         loc, getTypeFromShape(permuteShape, inOptionalDType), expandedTensor,
+         permuteDimsOrder);
+
+     // Collapse (C, groups) back into a single channel dimension
+     rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
+                                                  dimC, dimG);
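+     // Note: the collapse reuses the op's original result type, which is
+     // valid because channel_shuffle preserves the input shape.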
+
+     return success();
+   }
+ };
+ } // namespace
+
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                             Value input) {
@@ -12518,6 +12661,7 @@ class DecomposeComplexOpsPass
    addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
+   addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
        patterns);