
Commit e03f7c6

Decompose aten.channel_shuffle op (#4243) (#4259)
Support for the channel shuffle operator is added via a torch-dialect-level decomposition (similar to the pixel_shuffle operation). The decomposition is based on the specification at https://docs.pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html and on the PyTorch implementation in aten/src/ATen/native/ChanelShuffle.cpp: https://github.com/pytorch/pytorch/blob/23491519d288dedb2a54cfad5fef7fcb2ad8eade/aten/src/ATen/native/ChanelShuffle.cpp#L4
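
For intuition, here is a minimal PyTorch sketch (ours, not part of the commit) of the split -> permute -> collapse sequence the decomposition performs, checked against torch.nn.ChannelShuffle; the helper name channel_shuffle_decomposed is hypothetical:

import torch

# Hand-rolled reference of the decomposition described above (helper name is
# ours): mirrors prims.split_dim -> aten.permute -> prims.collapse.
def channel_shuffle_decomposed(x: torch.Tensor, groups: int) -> torch.Tensor:
    n, c, *spatial = x.shape
    x = x.view(n, groups, c // groups, *spatial)  # split the channel dim
    x = x.transpose(1, 2)                         # swap (groups, c//groups)
    return x.reshape(n, c, *spatial)              # collapse back to (n, c, ...)

x = torch.randn(2, 6, 4, 4)
assert torch.equal(channel_shuffle_decomposed(x, 3),
                   torch.nn.ChannelShuffle(3)(x))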
1 parent 3eb2475 commit e03f7c6

File tree: 10 files changed, +435 -43 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
   }];
 }
 
+def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::channel_shuffle : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenChannelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenChannelShuffleOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
     AllowsTypeRefinement,
     ReadOnly
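
GeneratedTorchOps.td is emitted by torch-mlir's ODS generator rather than edited by hand. A hypothetical sketch of the registration in build_tools/torch_ods_gen.py (one of the changed files not reproduced in this excerpt; the exact placement is an assumption based on how similar generated ops are registered):

# Hypothetical: registration alongside the existing aten::pixel_shuffle entry.
emit("aten::channel_shuffle : (Tensor, int) -> (Tensor)")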

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 18 additions & 0 deletions
@@ -7613,6 +7613,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %15 = torch.aten.append.t %6, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    return %6 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.channel_shuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: input must be at least rank-3 in channel_shuffle\"\n"
+"    %int3 = torch.constant.int 3\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12305,6 +12319,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.channel_shuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 175 additions & 31 deletions
@@ -3537,6 +3537,30 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
 };
 } // namespace
 
+namespace { // Start of rearrangement ops utility functions
+// Extracts shape as vector of int64_t from vector of Value
+SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
+  SmallVector<int64_t> shape;
+  shape.reserve(vals.size());
+  for (Value v : vals) {
+    int64_t cst_val;
+    if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
+      shape.push_back(cst_val);
+    } else {
+      shape.push_back(kUnknownSize);
+    }
+  }
+  return shape;
+}
+
+// Converts a vector of Value (shape dimensions) into a ValueTensorType
+ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
+  SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
+  return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
+                              inOptionalDType);
+}
+} // namespace
+
 // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
 // prims.collapse operations.
 //
@@ -3562,7 +3586,6 @@ class DecomposeAtenPixelShuffleOp
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
                                 PatternRewriter &rewriter) const override {
-
     Location loc = op.getLoc();
     Value inValue = op.getSelf();
     auto inType = cast<BaseTensorType>(inValue.getType());
@@ -3585,27 +3608,6 @@ class DecomposeAtenPixelShuffleOp
 
     const auto inOptionalDType = inType.getOptionalDtype();
 
-    auto getTypeFromShape = [inOptionalDType](auto &&vals) {
-      // Get a vector of integers from a vector of Values.
-      auto getIntShape = [](auto &&vals) {
-        SmallVector<int64_t> shape;
-        shape.reserve(vals.size());
-        for (auto v : vals) {
-          int64_t cst_val;
-          if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
-            shape.push_back(cst_val);
-          } else {
-            shape.push_back(kUnknownSize);
-          }
-        }
-        return shape;
-      };
-
-      const auto intShape = getIntShape(vals);
-      return ValueTensorType::get(vals[0].getContext(),
-                                  llvm::ArrayRef(intShape), inOptionalDType);
-    };
-
     auto nLeadingDims = inRank - 3;
 
     // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
@@ -3677,24 +3679,24 @@ class DecomposeAtenPixelShuffleOp
     auto partiallyExpanded =
         rewriter
            .create<PrimsSplitDimOp>(
-                loc, getTypeFromShape(partiallyExpandedShape), inValue,
-                dimensionConstants[nLeadingDims], outC)
+                loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType),
+                inValue, dimensionConstants[nLeadingDims], outC)
            .getResult();
 
     // Split new dimension factorSquared -> (factor, factor)
     auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
-        loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
-        dimensionConstants[nLeadingDims + 1], factor);
+        loc, getTypeFromShape(prePermuteShape, inOptionalDType),
+        partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);
 
     // Perform the permutation
-    auto permuted =
-        rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
-                                       fullyExpanded, permuteDimsOrder);
+    auto permuted = rewriter.create<AtenPermuteOp>(
+        loc, getTypeFromShape(postPermuteShape, inOptionalDType), fullyExpanded,
+        permuteDimsOrder);
 
     // Collapse final 2 dimension
     auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
-        loc, getTypeFromShape(partiallyCollapsedShape), permuted,
-        dimensionConstants[nLeadingDims + 3],
+        loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
+        permuted, dimensionConstants[nLeadingDims + 3],
         dimensionConstants[nLeadingDims + 4]);
 
     // Collapse back to original rank
@@ -3708,6 +3710,147 @@ class DecomposeAtenPixelShuffleOp
 };
 } // namespace
 
+// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
+// prims.collapse operations.
+//
+// If input is a tensor of shape
+//   (N, g*C, H, W),
+//
+// then
+//   X = channel_shuffle(input, groups)
+//
+// gets replaced with
+//   X = input.split_dim(...)    # shape (N, g, C, *)
+//   X = X.permute(0, 2, 1, ...) # shape (N, C, g, *)
+//   X = X.collapse(...)         # shape (N, C*g, *)
+//
+// 'g' above is referred to as the number of 'groups'. N is the batch
+// dimension, and can't be omitted. In PyTorch's ChannelShuffle operator,
+// if the batch dimension is omitted, the first spatial dimension is treated
+// as the channel dimension. PyTorch errors out for the code below, indicating
+// that 4 is not divisible by 3:
+//   input_tensor = torch.arange(1, 37, dtype=torch.float32).view(3, 4, 3)
+//   channel_shuffle_layer = nn.ChannelShuffle(groups=3)
+//   output_tensor = channel_shuffle_layer(input_tensor)
+//
+// The decomposition is based on this specification:
+//   https://pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
+// and PyTorch implementation: aten/src/ATen/native/ChanelShuffle.cpp
+// (yes, the filename is misspelled "Chanel" in upstream PyTorch)
+//
+namespace {
+class DecomposeAtenChannelShuffleOp
+    : public OpRewritePattern<AtenChannelShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenChannelShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value inValue = op.getSelf();
+    auto inType = cast<BaseTensorType>(inValue.getType());
+    auto maybeSizes = inType.getOptionalSizes();
+    if (!maybeSizes) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have known rank.");
+    }
+    auto inShape = maybeSizes.value();
+    auto inRank = inShape.size();
+
+    // The input tensor must have at least 3 dimensions: batch size,
+    // channel size, and at least one spatial dimension.
+    if (inRank < 3)
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have rank greater than or equal to 3.");
+
+    auto numOfSpatialDims = inRank - 2;
+
+    // Get the size of the dimension 'i'. Note the use of 'createOrFold'
+    // instead of 'create': if the dimension size is known, then the
+    // AtenSizeIntOp is folded to a ConstantOp.
+    auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
+      Value dim =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
+    };
+
+    // The channel dimension is always the second dimension. PyTorch errors out
+    // if the batch dimension (first dimension) is not present. See comment at
+    // the top of this class for details.
+    auto inC = getDimSize(1);
+    SmallVector<Value> inSpatialDims;
+    inSpatialDims.reserve(numOfSpatialDims);
+    for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
+      inSpatialDims.push_back(getDimSize(i));
+    }
+
+    auto groups = op.getGroups();
+
+    // Temporary channel dimension size: tempC = inC / groups
+    // Assumes input has been validated: `inC % groups == 0`
+    // This is enforced by PyTorch's runtime and is required for correctness.
+    Value tempC = rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, groups);
+
+    // Create constants for split/permute/collapse operations. Note that we
+    // need an extra constant for the channel dimension split.
+    SmallVector<Value> dimensionConstants;
+    dimensionConstants.reserve(inRank + 1);
+    for (unsigned i = 0; i < inRank + 1; ++i) {
+      dimensionConstants.push_back(
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
+    }
+
+    Value batchDimSize = rewriter.createOrFold<AtenSizeIntOp>(
+        loc, inValue, dimensionConstants[0]);
+
+    SmallVector<Value> splitShape;
+    splitShape.reserve(inRank + 1);
+    splitShape.append({batchDimSize, groups, tempC});
+    splitShape.append(inSpatialDims); // Appends all spatial dimensions
+
+    SmallVector<Value> permuteShape;
+    permuteShape.reserve(inRank + 1);
+    permuteShape.append({batchDimSize, tempC, groups});
+    permuteShape.append(inSpatialDims); // Appends all spatial dimensions
+
+    // Permute (N, groups, tempC, *) -> (N, tempC, groups, *)
+    SmallVector<Value> permutation{dimensionConstants[0],  // batch dimension
+                                   dimensionConstants[2],  // tempC
+                                   dimensionConstants[1]}; // groups
+    for (unsigned i = 3; i < inRank + 1; ++i) {
+      permutation.push_back(dimensionConstants[i]);
+    }
+
+    Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+        permutation);
+
+    const auto inOptionalDType = inType.getOptionalDtype();
+
+    Value dimC = dimensionConstants[1];
+    Value dimG = dimensionConstants[2];
+
+    // Split input channel inC -> (groups, inC/groups)
+    auto expandedTensor =
+        rewriter
+            .create<PrimsSplitDimOp>(
+                loc, getTypeFromShape(splitShape, inOptionalDType), inValue,
+                dimC, tempC)
+            .getResult();
+
+    // Perform the permutation
+    auto permuted = rewriter.create<AtenPermuteOp>(
+        loc, getTypeFromShape(permuteShape, inOptionalDType), expandedTensor,
+        permuteDimsOrder);
+
+    // Collapse (C, groups) back into a single channel dimension
+    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
+                                                 dimC, dimG);
+
+    return success();
+  }
+};
+} // namespace
+
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
@@ -12518,6 +12661,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
        patterns);
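
A quick eager-mode sanity check (ours, not part of the commit) of the split -> permute -> collapse order implemented by the pattern: with C = 6 and groups = 3, input channel i carries the value i, so the output channel order exposes the shuffle directly:

import torch

x = torch.arange(6.0).view(1, 6, 1, 1)  # channel i holds the value i
y = torch.nn.ChannelShuffle(3)(x)
print(y.flatten().tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]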

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<Aten_LinalgDetOp>();
   target.addIllegalOp<AtenLinalgSlogdetOp>();
   target.addIllegalOp<AtenPixelShuffleOp>();
+  target.addIllegalOp<AtenChannelShuffleOp>();
   target.addIllegalOp<AtenTOp>();
   target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
   target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 11 additions & 11 deletions
@@ -317,17 +317,17 @@ bool Torch::isViewLikeOp(Operation *op) {
   // correct. We could potentially be more precise and identify the cases
   // that it does not return a view and treat those as having value
   // semantics.
-  return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
-             AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
-             AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
-             Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
-             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
-             AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-             AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
-             AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
-             PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
-             AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
+  return isa<
+      AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
+      AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
+      AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
+      AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
+      AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
+      TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp,
+      AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
+      PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
+      AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
+      AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
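
For context on why AtenChannelShuffleOp joins this list: isViewLikeOp conservatively flags ops that may alias their input, presumably because the decomposed form is built from view-producing ops such as prims.split_dim. A small PyTorch illustration (ours) of such aliasing, using aten.t, another entry in the list:

import torch

x = torch.zeros(2, 3)
v = x.t()        # a view: v shares storage with x
x[0, 1] = 7.0
print(v[1, 0])   # tensor(7.) -- the write to x is visible through the view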
