[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op #4278
base: main
Changes from all commits: 9068bbc, be02c1f, 9b44fc7, f409948, 111dfda, 82215f3
@@ -23,6 +23,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include <iostream>
#include <set>

using namespace mlir;

@@ -3708,6 +3709,177 @@ class DecomposeAtenPixelShuffleOp
};
} // namespace

// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
// This is the exact inverse of aten.pixel_shuffle.
//
// If the input is a tensor of shape
//   (*leading_dims, C, H*r, W*r),
//
// where leading_dims is of size N, then
//   X = pixel_unshuffle(input, downscale_factor)
//
// gets replaced with
//   X = input.split_dim(...)  # shape (*leading_dims, C, H, r, W*r)

alaa-ali marked this conversation as resolved.

//   X = X.split_dim(...)      # shape (*leading_dims, C, H, r, W, r)
//   X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
//                             # shape (*leading_dims, C, r, r, H, W)
//   X = X.collapse(...)       # shape (*leading_dims, C, r*r, H, W)
//   X = X.collapse(...)       # shape (*leading_dims, C*r*r, H, W)

Comment on lines +3728 to +3729: Do we need two collapses -- isn't collapsing directly to

//
// 'r' above is referred to as the 'downscale factor' or just 'factor' below.
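As an editorial aside (not part of the diff), the shape bookkeeping above can be traced for a concrete input. A minimal sketch, assuming one leading batch dimension and made-up sizes:

```cpp
// Shape trace of the decomposition (editorial sketch, invented sizes).
// With one leading dim (N = 1) the permutation is (0, 1, 3, 5, 2, 4).
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const int64_t B = 1, C = 3, H = 4, W = 6, r = 2;
  const std::vector<std::pair<const char *, std::vector<int64_t>>> steps = {
      {"input", {B, C, H * r, W * r}},          // (1, 3, 8, 12)
      {"split_dim H*r", {B, C, H, r, W * r}},   // (1, 3, 4, 2, 12)
      {"split_dim W*r", {B, C, H, r, W, r}},    // (1, 3, 4, 2, 6, 2)
      {"permute", {B, C, r, r, H, W}},          // (1, 3, 2, 2, 4, 6)
      {"collapse r,r", {B, C, r * r, H, W}},    // (1, 3, 4, 4, 6)
      {"collapse C,r*r", {B, C * r * r, H, W}}, // (1, 12, 4, 6)
  };
  for (const auto &[name, shape] : steps) {
    std::cout << name << ':';
    for (int64_t d : shape)
      std::cout << ' ' << d;
    std::cout << '\n';
  }
  return 0;
}
```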
namespace {
class DecomposeAtenPixelUnshuffleOp
    : public OpRewritePattern<AtenPixelUnshuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Value inValue = op.getSelf();
    auto inType = cast<BaseTensorType>(inValue.getType());
    auto maybeSizes = inType.getOptionalSizes();
    if (!maybeSizes) {
      return rewriter.notifyMatchFailure(
          op, "Expected input tensor to have known rank.");
    }
    auto inShape = maybeSizes.value();
    auto inRank = inShape.size();

    // The input tensor must have at least 3 dimensions: (1) the channel
    // dimension C, which grows by 'factor*factor', (2) the height dimension,
    // which shrinks by 'factor', and (3) the width dimension, which also
    // shrinks by 'factor'. The total number of dimensions is 3 + N, where N
    // is the number of leading dimensions (N >= 0), so the input must have
    // rank at least 3; a rank-3 input (C, H*r, W*r) is the minimal case with
    // N = 0.
    if (inRank < 3)
      return rewriter.notifyMatchFailure(
          op, "Expected input tensor to have rank greater than 2.");

    const auto inOptionalDType = inType.getOptionalDtype();

    auto getTypeFromShape = [inOptionalDType](auto &&vals) {
      // Get a vector of integers from a vector of Values.
      auto getIntShape = [](auto &&vals) {
        SmallVector<int64_t> shape;
        shape.reserve(vals.size());
        for (auto v : vals) {
          int64_t cst_val;
          if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
            shape.push_back(cst_val);
          } else {
            shape.push_back(kUnknownSize);
          }
        }
        return shape;
      };

      const auto intShape = getIntShape(vals);
      return ValueTensorType::get(vals[0].getContext(),
                                  llvm::ArrayRef(intShape), inOptionalDType);
    };
Comment on lines +3762 to +3781: These methods were added as utilities in #4259. Once that is merged, can you update your code to reuse the utilities?
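For what getIntShape is doing, a hedged standalone analogue (toStaticShape and kUnknownDim are invented names for this sketch, not torch-mlir APIs): dimensions whose Values match integer constants become static sizes, and everything else becomes a dynamic-size sentinel.

```cpp
// Editorial analogue of getIntShape: constant dims become static sizes,
// non-constant dims become a sentinel (kUnknownSize plays this role above).
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kUnknownDim = -1; // invented stand-in for kUnknownSize

std::vector<int64_t>
toStaticShape(const std::vector<std::optional<int64_t>> &dims) {
  std::vector<int64_t> shape;
  shape.reserve(dims.size());
  for (const auto &d : dims)
    shape.push_back(d.has_value() ? *d : kUnknownDim);
  return shape;
}
// toStaticShape({3, std::nullopt, 8}) yields {3, kUnknownDim, 8}.
```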
    auto nLeadingDims = inRank - 3;

    // Get the size of dimension 'i'. Note the use of 'createOrFold' instead
    // of 'create': if the dimension size is known, the AtenSizeIntOp is
    // folded to a ConstantOp.
    auto getDimSize = [&](uint64_t i) -> Value {
      Value dim =
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
    };

Comment on lines +3785 to +3792: This is shared with #4259 as well. It'll be good to move this into a utility method and probably move all these utilities to
    auto inC = getDimSize(inRank - 3);
    auto inH = getDimSize(inRank - 2);
    auto inW = getDimSize(inRank - 1);

    auto factor = op.getDownscaleFactor();

    Value factorSquared =
        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);

    Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);

    Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
    Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);

    SmallVector<Value> dimensionConstants;
    dimensionConstants.reserve(inRank + 2);
    for (unsigned i = 0; i < inRank + 2; ++i) {
      dimensionConstants.push_back(
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
    }

    SmallVector<Value> leadingDims;
    leadingDims.reserve(nLeadingDims);
    for (unsigned i = 0; i < nLeadingDims; ++i) {
      Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
          loc, inValue, dimensionConstants[i]);
      leadingDims.push_back(leadingDimSize);
    }

    SmallVector<Value> partiallyExpandedShape = leadingDims;
    partiallyExpandedShape.append({inC, outH, factor, inW});

Comment on lines +3823 to +3824: Can you move this before its use and also rename to
    SmallVector<Value> prePermuteShape = leadingDims;
    prePermuteShape.append({inC, outH, factor, outW, factor});

    SmallVector<Value> postPermuteShape = leadingDims;
    postPermuteShape.append({inC, factor, factor, outH, outW});

    SmallVector<Value> partiallyCollapsedShape = leadingDims;
    partiallyCollapsedShape.append({inC, factorSquared, outH, outW});

    SmallVector<Value> outShape = leadingDims;
    outShape.append({outC, outH, outW});

    SmallVector<Value> permutation{dimensionConstants.begin(),
                                   dimensionConstants.begin() + nLeadingDims};
    SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
    for (uint64_t d : permutationTail) {
      permutation.push_back(dimensionConstants[nLeadingDims + d]);
    }

    Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
        permutation);

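As a sanity check on the permutation built above (an editorial sketch with an invented helper name): the leading dims stay in place and the trailing five dims are reordered by offsets (0, 2, 4, 1, 3), so one leading dimension gives the order [0, 1, 3, 5, 2, 4].

```cpp
// Editorial sketch of the permutation construction: keep leading dims,
// then map the trailing (C, H, r, W, r) block to (C, r, r, H, W).
#include <cstdint>
#include <vector>

std::vector<int64_t> buildUnshufflePermutation(int64_t nLeadingDims) {
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < nLeadingDims; ++i)
    perm.push_back(i); // leading dims are untouched
  for (int64_t d : {0, 2, 4, 1, 3})
    perm.push_back(nLeadingDims + d);
  return perm;
}
// buildUnshufflePermutation(1) yields {0, 1, 3, 5, 2, 4}.
```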
    // Split the height dimension inH -> (outH, factor)
    auto partiallyExpanded =
        rewriter
            .create<PrimsSplitDimOp>(
                loc, getTypeFromShape(partiallyExpandedShape), inValue,
                dimensionConstants[nLeadingDims + 1], outH)
            .getResult();

    // Split the width dimension inW -> (outW, factor)
    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
        loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
        dimensionConstants[nLeadingDims + 3], outW);

    // Perform the permutation
    auto permuted =
        rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
                                       fullyExpanded, permuteDimsOrder);

    // Collapse the two factor dimensions: (r, r) -> r*r
    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
        loc, getTypeFromShape(partiallyCollapsedShape), permuted,
        dimensionConstants[nLeadingDims + 1],
        dimensionConstants[nLeadingDims + 2]);

    // Collapse (C, r*r) -> C*r*r, back to the original rank
    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
        op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims],
        dimensionConstants[nLeadingDims + 1]);

    return success();
  }
};
} // namespace

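Stepping back from the IR plumbing, the end-to-end mapping this decomposition must reproduce can be written as a plain reference loop. A sketch under stated assumptions (single NCHW image, invented function name): output channel c*r*r + dh*r + dw at (h, w) reads input channel c at (h*r + dh, w*r + dw).

```cpp
// Reference semantics of pixel_unshuffle (editorial sketch, not from the
// patch): input (C, H*r, W*r) -> output (C*r*r, H, W).
#include <cstdint>
#include <vector>

std::vector<float> pixelUnshuffleRef(const std::vector<float> &in, int64_t C,
                                     int64_t H, int64_t W, int64_t r) {
  // H and W are the *output* height and width; the input is (C, H*r, W*r).
  std::vector<float> out(static_cast<size_t>(C * r * r * H * W));
  for (int64_t c = 0; c < C; ++c)
    for (int64_t h = 0; h < H; ++h)
      for (int64_t dh = 0; dh < r; ++dh)
        for (int64_t w = 0; w < W; ++w)
          for (int64_t dw = 0; dw < r; ++dw) {
            const int64_t outC = (c * r + dh) * r + dw;
            const int64_t inIdx =
                (c * H * r + h * r + dh) * (W * r) + (w * r + dw);
            const int64_t outIdx = (outC * H + h) * W + w;
            out[outIdx] = in[inIdx];
          }
  return out;
}
```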
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                             Value input) {

@@ -12518,6 +12690,7 @@ class DecomposeComplexOpsPass
    addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
        patterns);