Skip to content

Commit 39a8540

Browse files
committed
Add lowering of torch.aten.pixel_unshuffle op to linalg
1 parent b1053f8 commit 39a8540

File tree

9 files changed

+400
-1
lines changed

9 files changed

+400
-1
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
86688668
}];
86698669
}
86708670

8671+
// Generated op for `aten::pixel_unshuffle` — the summary below marks this as
// emitted by the ODS generator (torch_ods_gen.py); edit the generator, not
// this definition, to change it.
def Torch_AtenPixelUnshuffleOp : Torch_Op<"aten.pixel_unshuffle", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::pixel_unshuffle : (Tensor, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$downscale_factor
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenPixelUnshuffleOp::parse(OpAsmParser &parser, OperationState &result) {
      // 2 operands, 1 result — matches the `(Tensor, int) -> (Tensor)` schema.
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenPixelUnshuffleOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}
8694+
86718695
def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
86728696
AllowsTypeRefinement,
86738697
HasValueSemantics,

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3710,6 +3710,177 @@ class DecomposeAtenPixelShuffleOp
37103710
};
37113711
} // namespace
37123712

3713+
// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
3714+
// prims.collapse operations.
3715+
//
3716+
// We want to do the exact opposite of aten.pixel_shuffle
3717+
//
3718+
// If input is a tensor of shape
3719+
// (*leading_dims, C, H*r, W*r),
3720+
//
3721+
// where leading_dims is of size N, then
3722+
// X = pixel_unshuffle(input, downscale_factor)
3723+
//
3724+
// gets replaced with
3725+
// X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r)
3726+
// X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r)
3727+
// X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
3728+
// # shape (*leading_dims, C, r, r, H, W)
3729+
// X = X.collapse(...) # shape (*leading_dims, C, r*r, H, W)
3730+
// X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
3731+
//
3732+
// 'r' above is referred to as the 'downscale factor' or just 'factor' below.
3733+
namespace {
3734+
class DecomposeAtenPixelUnshuffleOp
3735+
: public OpRewritePattern<AtenPixelUnshuffleOp> {
3736+
public:
3737+
using OpRewritePattern::OpRewritePattern;
3738+
LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
3739+
PatternRewriter &rewriter) const override {
3740+
3741+
Location loc = op.getLoc();
3742+
Value inValue = op.getSelf();
3743+
auto inType = cast<BaseTensorType>(inValue.getType());
3744+
auto maybeSizes = inType.getOptionalSizes();
3745+
if (!maybeSizes) {
3746+
return rewriter.notifyMatchFailure(
3747+
op, "Expected input tensor to have known rank.");
3748+
}
3749+
auto inShape = maybeSizes.value();
3750+
auto inRank = inShape.size();
3751+
3752+
// The input tensor must have at least 3 dimensions: (1) the channel
3753+
// dimension which gets bigger by 'factor*factor', (2) the H channel which
3754+
// gets smaller by 'factor' and (3) the W channel which get smaller by
3755+
// 'factor'. The total number of dimensions is 3 + N, where N is the number
3756+
// of leading dimensions, and N >= 0 so the input must have rank at least 3.
3757+
if (inRank < 3)
3758+
return rewriter.notifyMatchFailure(
3759+
op, "Expected input tensor to have rank greater than 2.");
3760+
3761+
const auto inOptionalDType = inType.getOptionalDtype();
3762+
3763+
auto getTypeFromShape = [inOptionalDType](auto &&vals) {
3764+
// Get a vector of integers from a vector of Values.
3765+
auto getIntShape = [](auto &&vals) {
3766+
SmallVector<int64_t> shape;
3767+
shape.reserve(vals.size());
3768+
for (auto v : vals) {
3769+
int64_t cst_val;
3770+
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3771+
shape.push_back(cst_val);
3772+
} else {
3773+
shape.push_back(kUnknownSize);
3774+
}
3775+
}
3776+
return shape;
3777+
};
3778+
3779+
const auto intShape = getIntShape(vals);
3780+
return ValueTensorType::get(vals[0].getContext(),
3781+
llvm::ArrayRef(intShape), inOptionalDType);
3782+
};
3783+
3784+
auto nLeadingDims = inRank - 3;
3785+
3786+
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
3787+
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
3788+
// folded to a ConstantOp.
3789+
auto getDimSize = [&](uint64_t i) -> Value {
3790+
Value dim =
3791+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3792+
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3793+
};
3794+
3795+
auto inC = getDimSize(inRank - 3);
3796+
auto inH = getDimSize(inRank - 2);
3797+
auto inW = getDimSize(inRank - 1);
3798+
3799+
auto factor = op.getDownscaleFactor();
3800+
3801+
Value factorSquared =
3802+
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
3803+
3804+
Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
3805+
3806+
Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
3807+
Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
3808+
3809+
SmallVector<Value> dimensionConstants;
3810+
dimensionConstants.reserve(inRank + 2);
3811+
for (unsigned i = 0; i < inRank + 2; ++i) {
3812+
dimensionConstants.push_back(
3813+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3814+
}
3815+
3816+
SmallVector<Value> leadingDims;
3817+
leadingDims.reserve(nLeadingDims);
3818+
for (unsigned i = 0; i < nLeadingDims; ++i) {
3819+
Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3820+
loc, inValue, dimensionConstants[i]);
3821+
leadingDims.push_back(leadingDimSize);
3822+
}
3823+
3824+
SmallVector<Value> partiallyExpandedShape = leadingDims;
3825+
partiallyExpandedShape.append({inC, outH, factor, inW});
3826+
3827+
SmallVector<Value> prePermuteShape = leadingDims;
3828+
prePermuteShape.append({inC, outH, factor, outW, factor});
3829+
3830+
SmallVector<Value> postPermuteShape = leadingDims;
3831+
postPermuteShape.append({inC, factor, factor, outH, outW});
3832+
3833+
SmallVector<Value> partiallyCollapsedShape = leadingDims;
3834+
partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
3835+
3836+
SmallVector<Value> outShape = leadingDims;
3837+
outShape.append({outC, outH, outW});
3838+
3839+
SmallVector<Value> permutation{dimensionConstants.begin(),
3840+
dimensionConstants.begin() + nLeadingDims};
3841+
SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
3842+
for (uint64_t d : permutationTail) {
3843+
permutation.push_back(dimensionConstants[nLeadingDims + d]);
3844+
}
3845+
3846+
Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3847+
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3848+
permutation);
3849+
3850+
// Split input channel inH -> (outH, factor)
3851+
auto partiallyExpanded =
3852+
rewriter
3853+
.create<PrimsSplitDimOp>(
3854+
loc, getTypeFromShape(partiallyExpandedShape), inValue,
3855+
dimensionConstants[nLeadingDims + 1], outH)
3856+
.getResult();
3857+
3858+
// Split new dimension inW -> (outW, factor)
3859+
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3860+
loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
3861+
dimensionConstants[nLeadingDims + 3], outW);
3862+
3863+
// Perform the permutation
3864+
auto permuted =
3865+
rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
3866+
fullyExpanded, permuteDimsOrder);
3867+
3868+
// Collapse final 2 dimension
3869+
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3870+
loc, getTypeFromShape(partiallyCollapsedShape), permuted,
3871+
dimensionConstants[nLeadingDims + 1],
3872+
dimensionConstants[nLeadingDims + 2]);
3873+
3874+
// Collapse back to original rank
3875+
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
3876+
op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims],
3877+
dimensionConstants[nLeadingDims + 1]);
3878+
3879+
return success();
3880+
}
3881+
};
3882+
} // namespace
3883+
37133884
// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
37143885
// prims.collapse operations.
37153886
//
@@ -12859,6 +13030,7 @@ class DecomposeComplexOpsPass
1285913030
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
1286013031
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
1286113032
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
13033+
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
1286213034
addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
1286313035
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
1286413036
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
421421
target.addIllegalOp<Aten_LinalgDetOp>();
422422
target.addIllegalOp<AtenLinalgSlogdetOp>();
423423
target.addIllegalOp<AtenPixelShuffleOp>();
424+
target.addIllegalOp<AtenPixelUnshuffleOp>();
424425
target.addIllegalOp<AtenChannelShuffleOp>();
425426
target.addIllegalOp<AtenTOp>();
426427
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ bool Torch::isViewLikeOp(Operation *op) {
327327
AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
328328
PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
329329
AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
330-
AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
330+
AtenPixelUnshuffleOp, AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
331331
}
332332

333333
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,12 @@
819819
"PrimsSqueezeModule_basic",
820820
"PrimsViewOfModule_basic",
821821
"PrimsViewOfZeroRankModule_basic",
822+
"PixelUnshuffleModuleFullDynamic_basic",
823+
"PixelUnshuffleModuleSpatiallyDynamic_basic",
824+
"PixelUnshuffleModuleSpatiallyStatic_basic",
825+
"PixelUnshuffleModuleStaticRank3Int64_basic",
826+
"PixelUnshuffleModuleStaticRank4Float32_basic",
827+
"PixelUnshuffleModuleStaticRank5Float32_basic",
822828
"QuantizedBatchedInputSingleLayer_basic",
823829
"QuantizedMLP_basic",
824830
"QuantizedNoLayer_basic",
@@ -3127,6 +3133,11 @@
31273133
"PixelShuffleModuleSpatiallyDynamic_basic",
31283134
"PixelShuffleModuleSpatiallyStatic_basic",
31293135
"PixelShuffleModuleStaticRank3Int64_basic",
3136+
"PixelUnshuffleModuleStaticRank5Float32_basic",
3137+
"PixelUnshuffleModuleStaticRank3Int64_basic",
3138+
"PixelUnshuffleModuleFullDynamic_basic",
3139+
"PixelUnshuffleModuleSpatiallyDynamic_basic",
3140+
"PixelUnshuffleModuleSpatiallyStatic_basic",
31303141
"ChannelShuffleBasic_basic",
31313142
"ChannelShuffleUnitaryGroup_basic",
31323143
"ChannelShuffle1D_basic",
@@ -4738,6 +4749,11 @@
47384749
"PixelShuffleModuleSpatiallyStatic_basic",
47394750
"PixelShuffleModuleStaticRank3Int64_basic",
47404751
"PixelShuffleModuleStaticRank4Float32_basic",
4752+
"PixelUnshuffleModuleStaticRank5Float32_basic",
4753+
"PixelUnshuffleModuleStaticRank3Int64_basic",
4754+
"PixelUnshuffleModuleFullDynamic_basic",
4755+
"PixelUnshuffleModuleSpatiallyDynamic_basic",
4756+
"PixelUnshuffleModuleSpatiallyStatic_basic",
47414757
"ChannelShuffleBasic_basic",
47424758
"ChannelShuffleUnitaryGroup_basic",
47434759
"ChannelShuffle1D_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,21 @@ def aten〇channel_shuffle〡shape(self: List[int], groups: int) -> List[int]:
843843
assert len(self) >= 3, "input must be at least rank-3 in channel_shuffle"
844844
return self
845845

846+
def aten〇pixel_unshuffle〡shape(self: List[int], downscale_factor: int) -> List[int]:
847+
848+
assert len(self) >= 3, "input must be at least rank-3 in pixel_unshuffle"
849+
downscale_factor_squared = downscale_factor * downscale_factor
850+
assert self[-2] % (downscale_factor) == 0, "height must be divisible by downscale_factor in pixel_unshuffle"
851+
assert self[-1] % (downscale_factor) == 0, "width must be divisible by downscale_factor in pixel_unshuffle"
852+
853+
out = self[0:-3]
854+
out.append(self[-3] * downscale_factor_squared)
855+
out.append(self[-2] // downscale_factor)
856+
out.append(self[-1] // downscale_factor)
857+
return out
858+
859+
860+
846861
def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]:
847862
return upstream_shape_functions.permute(self, dims)
848863

@@ -3069,6 +3084,11 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
30693084
self_rank, self_dtype = self_rank_dtype
30703085
return self_dtype
30713086

3087+
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 2, 2)], downscale_factor = 2))
def aten〇pixel_unshuffle〡dtype(self_rank_dtype: Tuple[int, int], downscale_factor: int) -> int:
    # pixel_unshuffle rearranges elements only, so the output dtype is the
    # input dtype; the rank component of the tuple is not needed here.
    _, dtype = self_rank_dtype
    return dtype
3091+
30723092
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 4, 4, 5)], groups = 2))
30733093
def aten〇channel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], groups: int) -> int:
30743094
self_rank, self_dtype = self_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ def emit_with_mutating_variants(key, **kwargs):
719719
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
720720
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True)
721721
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
722+
emit("aten::pixel_unshuffle : (Tensor, int) -> (Tensor)")
722723
emit("aten::channel_shuffle : (Tensor, int) -> (Tensor)")
723724
emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
724725
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")

0 commit comments

Comments
 (0)