Skip to content

Commit 9068bbc

Browse files
committed
add lowering torch.aten.pixel_unshuffle op to linalg
1 parent 2c989a2 commit 9068bbc

File tree

10 files changed

+432
-11
lines changed

10 files changed

+432
-11
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
86688668
}];
86698669
}
86708670

8671+
def Torch_AtenPixelUnshuffleOp : Torch_Op<"aten.pixel_unshuffle", [
8672+
AllowsTypeRefinement,
8673+
HasValueSemantics,
8674+
ReadOnly
8675+
]> {
8676+
let summary = "Generated op for `aten::pixel_unshuffle : (Tensor, int) -> (Tensor)`";
8677+
let arguments = (ins
8678+
AnyTorchTensorType:$self,
8679+
Torch_IntType:$downscale_factor
8680+
);
8681+
let results = (outs
8682+
AnyTorchOptionalTensorType:$result
8683+
);
8684+
let hasCustomAssemblyFormat = 1;
8685+
let extraClassDefinition = [{
8686+
ParseResult AtenPixelUnshuffleOp::parse(OpAsmParser &parser, OperationState &result) {
8687+
return parseDefaultTorchOp(parser, result, 2, 1);
8688+
}
8689+
void AtenPixelUnshuffleOp::print(OpAsmPrinter &printer) {
8690+
printDefaultTorchOp(printer, *this, 2, 1);
8691+
}
8692+
}];
8693+
}
8694+
86718695
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
86728696
AllowsTypeRefinement,
86738697
ReadOnly

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7613,6 +7613,56 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
76137613
" %15 = torch.aten.append.t %6, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
76147614
" return %6 : !torch.list<int>\n"
76157615
" }\n"
7616+
" func.func @\"__torch_mlir_shape_fn.aten.pixel_unshuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
7617+
" %int1 = torch.constant.int 1\n"
7618+
" %int-3 = torch.constant.int -3\n"
7619+
" %str = torch.constant.str \"AssertionError: width must be divisible by downscale_factor in pixel_unshuffle\"\n"
7620+
" %int-1 = torch.constant.int -1\n"
7621+
" %str_0 = torch.constant.str \"AssertionError: height must be divisible by downscale_factor in pixel_unshuffle\"\n"
7622+
" %int-2 = torch.constant.int -2\n"
7623+
" %none = torch.constant.none\n"
7624+
" %str_1 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_unshuffle\"\n"
7625+
" %int3 = torch.constant.int 3\n"
7626+
" %int0 = torch.constant.int 0\n"
7627+
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
7628+
" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
7629+
" torch.prim.If %1 -> () {\n"
7630+
" torch.prim.If.yield\n"
7631+
" } else {\n"
7632+
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
7633+
" torch.prim.If.yield\n"
7634+
" }\n"
7635+
" %2 = torch.aten.mul.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7636+
" %3 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
7637+
" %4 = torch.aten.remainder.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7638+
" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
7639+
" torch.prim.If %5 -> () {\n"
7640+
" torch.prim.If.yield\n"
7641+
" } else {\n"
7642+
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
7643+
" torch.prim.If.yield\n"
7644+
" }\n"
7645+
" %6 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
7646+
" %7 = torch.aten.remainder.int %6, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7647+
" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
7648+
" torch.prim.If %8 -> () {\n"
7649+
" torch.prim.If.yield\n"
7650+
" } else {\n"
7651+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
7652+
" torch.prim.If.yield\n"
7653+
" }\n"
7654+
" %9 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list<int>, !torch.int, !torch.int, !torch.int -> !torch.list<int>\n"
7655+
" %10 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
7656+
" %11 = torch.aten.mul.int %10, %2 : !torch.int, !torch.int -> !torch.int\n"
7657+
" %12 = torch.aten.append.t %9, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
7658+
" %13 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
7659+
" %14 = torch.aten.floordiv.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7660+
" %15 = torch.aten.append.t %9, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
7661+
" %16 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
7662+
" %17 = torch.aten.floordiv.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7663+
" %18 = torch.aten.append.t %9, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
7664+
" return %9 : !torch.list<int>\n"
7665+
" }\n"
76167666
" func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
76177667
" %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
76187668
" return %0 : !torch.list<int>\n"
@@ -12275,6 +12325,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1227512325
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1227612326
" return %0#1 : !torch.int\n"
1227712327
" }\n"
12328+
" func.func @\"__torch_mlir_dtype_fn.aten.pixel_unshuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
12329+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12330+
" return %0#1 : !torch.int\n"
12331+
" }\n"
1227812332
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
1227912333
" %none = torch.constant.none\n"
1228012334
" %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/ADT/StringExtras.h"
2424
#include "llvm/ADT/StringSet.h"
2525
#include <cstdint>
26+
#include <iostream>
2627
#include <set>
2728

2829
using namespace mlir;
@@ -3708,6 +3709,177 @@ class DecomposeAtenPixelShuffleOp
37083709
};
37093710
} // namespace
37103711

3712+
// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
3713+
// prims.collapse operations.
3714+
//
3715+
// We want to do the exact opposite of aten.pixel_shuffle
3716+
//
3717+
// If input is a tensor of shape
3718+
// (*leading_dims, C, H*r, W*r),
3719+
//
3720+
// where leading_dims is of size N, then
3721+
// X = pixel_unshuffle(input, downscale_factor)
3722+
//
3723+
// gets replaced with
3724+
// X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r)
3725+
// X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r)
3726+
// X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
3727+
// # shape (*leading_dims, C, r, r, H, W)
3728+
// X = X.collapse(...) # shape (*leading_dims, C, r*r, H, W)
3729+
// X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
3730+
//
3731+
// 'r' above is referred to as the 'downscale factor' or just 'factor' below.
3732+
namespace {
3733+
class DecomposeAtenPixelUnshuffleOp
3734+
: public OpRewritePattern<AtenPixelUnshuffleOp> {
3735+
public:
3736+
using OpRewritePattern::OpRewritePattern;
3737+
LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
3738+
PatternRewriter &rewriter) const override {
3739+
3740+
Location loc = op.getLoc();
3741+
Value inValue = op.getSelf();
3742+
auto inType = cast<BaseTensorType>(inValue.getType());
3743+
auto maybeSizes = inType.getOptionalSizes();
3744+
if (!maybeSizes) {
3745+
return rewriter.notifyMatchFailure(
3746+
op, "Expected input tensor to have known rank.");
3747+
}
3748+
auto inShape = maybeSizes.value();
3749+
auto inRank = inShape.size();
3750+
3751+
// The input tensor must have at least 3 dimensions: (1) the channel
3752+
// dimension which gets bigger by 'factor*factor', (2) the H channel which
3753+
// gets smaller by 'factor' and (3) the W channel which get smaller by
3754+
// 'factor'. The total number of dimensions is 3 + N, where N is the number
3755+
// of leading dimensions, and N >= 0 so the input must have rank at least 3.
3756+
if (inRank < 3)
3757+
return rewriter.notifyMatchFailure(
3758+
op, "Expected input tensor to have rank greater than 2.");
3759+
3760+
const auto inOptionalDType = inType.getOptionalDtype();
3761+
3762+
auto getTypeFromShape = [inOptionalDType](auto &&vals) {
3763+
// Get a vector of integers from a vector of Values.
3764+
auto getIntShape = [](auto &&vals) {
3765+
SmallVector<int64_t> shape;
3766+
shape.reserve(vals.size());
3767+
for (auto v : vals) {
3768+
int64_t cst_val;
3769+
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3770+
shape.push_back(cst_val);
3771+
} else {
3772+
shape.push_back(kUnknownSize);
3773+
}
3774+
}
3775+
return shape;
3776+
};
3777+
3778+
const auto intShape = getIntShape(vals);
3779+
return ValueTensorType::get(vals[0].getContext(),
3780+
llvm::ArrayRef(intShape), inOptionalDType);
3781+
};
3782+
3783+
auto nLeadingDims = inRank - 3;
3784+
3785+
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
3786+
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
3787+
// folded to a ConstantOp.
3788+
auto getDimSize = [&](uint64_t i) -> Value {
3789+
Value dim =
3790+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3791+
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3792+
};
3793+
3794+
auto inC = getDimSize(inRank - 3);
3795+
auto inH = getDimSize(inRank - 2);
3796+
auto inW = getDimSize(inRank - 1);
3797+
3798+
auto factor = op.getDownscaleFactor();
3799+
3800+
Value factorSquared =
3801+
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
3802+
3803+
Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
3804+
3805+
Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
3806+
Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
3807+
3808+
SmallVector<Value> dimensionConstants;
3809+
dimensionConstants.reserve(inRank + 2);
3810+
for (unsigned i = 0; i < inRank + 2; ++i) {
3811+
dimensionConstants.push_back(
3812+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3813+
}
3814+
3815+
SmallVector<Value> leadingDims;
3816+
leadingDims.reserve(nLeadingDims);
3817+
for (unsigned i = 0; i < nLeadingDims; ++i) {
3818+
Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3819+
loc, inValue, dimensionConstants[i]);
3820+
leadingDims.push_back(leadingDimSize);
3821+
}
3822+
3823+
SmallVector<Value> partiallyExpandedShape = leadingDims;
3824+
partiallyExpandedShape.append({inC, outH, factor, inW});
3825+
3826+
SmallVector<Value> prePermuteShape = leadingDims;
3827+
prePermuteShape.append({inC, outH, factor, outW, factor});
3828+
3829+
SmallVector<Value> postPermuteShape = leadingDims;
3830+
postPermuteShape.append({inC, factor, factor, outH, outW});
3831+
3832+
SmallVector<Value> partiallyCollapsedShape = leadingDims;
3833+
partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
3834+
3835+
SmallVector<Value> outShape = leadingDims;
3836+
outShape.append({outC, outH, outW});
3837+
3838+
SmallVector<Value> permutation{dimensionConstants.begin(),
3839+
dimensionConstants.begin() + nLeadingDims};
3840+
SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
3841+
for (uint64_t d : permutationTail) {
3842+
permutation.push_back(dimensionConstants[nLeadingDims + d]);
3843+
}
3844+
3845+
Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3846+
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3847+
permutation);
3848+
3849+
// Split height dimension inH -> (outH, factor)
3850+
auto partiallyExpanded =
3851+
rewriter
3852+
.create<PrimsSplitDimOp>(
3853+
loc, getTypeFromShape(partiallyExpandedShape), inValue,
3854+
dimensionConstants[nLeadingDims + 1], outH)
3855+
.getResult();
3856+
3857+
// Split width dimension inW -> (outW, factor)
3858+
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3859+
loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
3860+
dimensionConstants[nLeadingDims + 3], outW);
3861+
3862+
// Perform the permutation
3863+
auto permuted =
3864+
rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
3865+
fullyExpanded, permuteDimsOrder);
3866+
3867+
// Collapse the two factor dimensions into one (factor*factor)
3868+
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3869+
loc, getTypeFromShape(partiallyCollapsedShape), permuted,
3870+
dimensionConstants[nLeadingDims + 1],
3871+
dimensionConstants[nLeadingDims + 2]);
3872+
3873+
// Collapse back to original rank
3874+
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
3875+
op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims],
3876+
dimensionConstants[nLeadingDims + 1]);
3877+
3878+
return success();
3879+
}
3880+
};
3881+
} // namespace
3882+
37113883
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
37123884
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
37133885
Value input) {
@@ -12514,6 +12686,7 @@ class DecomposeComplexOpsPass
1251412686
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
1251512687
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
1251612688
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
12689+
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
1251712690
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
1251812691
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
1251912692
patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
421421
target.addIllegalOp<Aten_LinalgDetOp>();
422422
target.addIllegalOp<AtenLinalgSlogdetOp>();
423423
target.addIllegalOp<AtenPixelShuffleOp>();
424+
target.addIllegalOp<AtenPixelUnshuffleOp>();
424425
target.addIllegalOp<AtenTOp>();
425426
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
426427
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -317,17 +317,17 @@ bool Torch::isViewLikeOp(Operation *op) {
317317
// correct. We could potentially be more precise and identify the cases
318318
// that it does not return a view and treat those as having value
319319
// semantics.
320-
return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
321-
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
322-
AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
323-
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
324-
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
325-
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
326-
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
327-
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
328-
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
329-
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
330-
AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
320+
return isa<
321+
AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
322+
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
323+
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
324+
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
325+
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
326+
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp,
327+
AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
328+
PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
329+
AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
330+
AtenPixelUnshuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
331331
}
332332

333333
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,12 @@
826826
"PrimsSqueezeModule_basic",
827827
"PrimsViewOfModule_basic",
828828
"PrimsViewOfZeroRankModule_basic",
829+
"PixelUnshuffleModuleFullDynamic_basic",
830+
"PixelUnshuffleModuleSpatiallyDynamic_basic",
831+
"PixelUnshuffleModuleSpatiallyStatic_basic",
832+
"PixelUnshuffleModuleStaticRank3Int64_basic",
833+
"PixelUnshuffleModuleStaticRank4Float32_basic",
834+
"PixelUnshuffleModuleStaticRank5Float32_basic",
829835
"QuantizedBatchedInputSingleLayer_basic",
830836
"QuantizedMLP_basic",
831837
"QuantizedNoLayer_basic",
@@ -3120,6 +3126,11 @@
31203126
"PixelShuffleModuleSpatiallyDynamic_basic",
31213127
"PixelShuffleModuleSpatiallyStatic_basic",
31223128
"PixelShuffleModuleStaticRank3Int64_basic",
3129+
"PixelUnshuffleModuleStaticRank5Float32_basic",
3130+
"PixelUnshuffleModuleStaticRank3Int64_basic",
3131+
"PixelUnshuffleModuleFullDynamic_basic",
3132+
"PixelUnshuffleModuleSpatiallyDynamic_basic",
3133+
"PixelUnshuffleModuleSpatiallyStatic_basic",
31233134
"PowIntIntModule_basic",
31243135
"PrimMaxIntModule_basic",
31253136
"PrimMinIntDynamicModule_basic",
@@ -4706,6 +4717,11 @@
47064717
"PixelShuffleModuleSpatiallyStatic_basic",
47074718
"PixelShuffleModuleStaticRank3Int64_basic",
47084719
"PixelShuffleModuleStaticRank4Float32_basic",
4720+
"PixelUnshuffleModuleStaticRank5Float32_basic",
4721+
"PixelUnshuffleModuleStaticRank3Int64_basic",
4722+
"PixelUnshuffleModuleFullDynamic_basic",
4723+
"PixelUnshuffleModuleSpatiallyDynamic_basic",
4724+
"PixelUnshuffleModuleSpatiallyStatic_basic",
47094725
"PrimMaxIntModule_basic",
47104726
"PrimMinIntDynamicModule_basic",
47114727
"PrimMinIntModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,19 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i
839839
out.append(self[-1] * upscale_factor)
840840
return out
841841

842+
def aten〇pixel_unshuffle〡shape(self: List[int], downscale_factor: int) -> List[int]:
843+
844+
assert len(self) >= 3, "input must be at least rank-3 in pixel_unshuffle"
845+
downscale_factor_squared = downscale_factor * downscale_factor
846+
assert self[-2] % (downscale_factor) == 0, "height must be divisible by downscale_factor in pixel_unshuffle"
847+
assert self[-1] % (downscale_factor) == 0, "width must be divisible by downscale_factor in pixel_unshuffle"
848+
849+
out = self[0:-3]
850+
out.append(self[-3] * downscale_factor_squared)
851+
out.append(self[-2] // downscale_factor)
852+
out.append(self[-1] // downscale_factor)
853+
return out
854+
842855

843856

844857
def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]:
@@ -3049,6 +3062,11 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
30493062
self_rank, self_dtype = self_rank_dtype
30503063
return self_dtype
30513064

3065+
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 2, 2)], downscale_factor = 2))
3066+
def aten〇pixel_unshuffle〡dtype(self_rank_dtype: Tuple[int, int], downscale_factor: int) -> int:
3067+
self_rank, self_dtype = self_rank_dtype
3068+
return self_dtype
3069+
30523070
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8}))
30533071
def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
30543072
self_rank, self_dtype = self_rank_dtype

0 commit comments

Comments
 (0)