fix: infer output shape directly in ConvertAtenUnflattenIntOp #4325
@@ -12,6 +12,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

 #include "PopulatePatterns.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
@@ -671,7 +672,8 @@ class ConvertAtenUnflattenIntOp
       return rewriter.notifyMatchFailure(op,
                                          "Expected input type having sizes");
     }
-    int inputRank = inputTensorType.getSizes().size();
+    auto inputTensorSizes = inputTensorType.getSizes();
+    int inputRank = inputTensorSizes.size();
     auto outputSizes = outputTensorType.getSizes();
     int outputRank = outputSizes.size();
@@ -692,15 +694,15 @@ class ConvertAtenUnflattenIntOp
       if (outputSizes[i] == Torch::kUnknownSize)
         numDynamicReassocDims++;
     }

     SmallVector<Value> reassocSizes;
     if (!getListConstructElements(op.getSizes(), reassocSizes) &&
         numDynamicReassocDims > 1)
       return rewriter.notifyMatchFailure(
           op, "Must be able to either infer expansion dims, or retrieve them "
               "from list construct");

-    auto expandTy = getTypeConverter()->convertType(outputTensorType);
+    RankedTensorType expandTy = cast<RankedTensorType>(
+        getTypeConverter()->convertType(outputTensorType));
     Value expand;
     // When there are less than two dynamic reassociation dims, this will lower
     // to tensor.expand_shape. Otherwise, this lowers to tensor.reshape.
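[Editorial note] To make the two lowering targets in that comment concrete, here is a minimal MLIR sketch of each; the value names are invented for illustration and do not come from the patch or its tests:

// At most one dynamic reassociation dim: the grouping is expressed
// structurally with tensor.expand_shape.
%a = tensor.expand_shape %x [[0], [1, 2]] output_shape [2, 2, 3]
    : tensor<2x6xf32> into tensor<2x2x3xf32>

// Two or more dynamic reassociation dims: the target shape is built at
// runtime and handed to tensor.reshape.
%shape = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
%b = tensor.reshape %y(%shape)
    : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>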
@@ -717,10 +719,80 @@ class ConvertAtenUnflattenIntOp
       for (int i = dimInt + numSizes; i < outputRank; ++i)
         reassociations[i - numSizes + 1].push_back(i);
     }
-      expand = rewriter
-                   .create<tensor::ExpandShapeOp>(
-                       loc, expandTy, adaptor.getSelf(), reassociations)
-                   .getResult();
+
+      auto sizeToOFR = [&](Value sizeVal) -> OpFoldResult {
+        int64_t constantSize;
+        if (matchPattern(sizeVal, m_TorchConstantInt(&constantSize))) {
+          return rewriter.getIndexAttr(constantSize);
+        }
+        SmallVector<Value> singleSizeVec = {sizeVal};
+        Value converted = castIntToIndex(
+            rewriter, loc,
+            getTypeConvertedValues(rewriter, loc, getTypeConverter(),
+                                   singleSizeVec)[0]);
+        return OpFoldResult(converted);
+      };
+
+      int64_t minusOneIdx = -1;
+      OpFoldResult knownProduct = rewriter.getIndexAttr(1);
+      AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext());
+      AffineExpr s1 = getAffineSymbolExpr(1, rewriter.getContext());
+      auto mulMap = AffineMap::get(0, 2, s0 * s1, rewriter.getContext());
+
+      for (int64_t j = 0, e = reassocSizes.size(); j < e; ++j) {
+        int64_t constantSize;
+        if (matchPattern(reassocSizes[j], m_TorchConstantInt(&constantSize)) &&
+            constantSize == -1) {
+          minusOneIdx = j;
+        } else {
+          knownProduct = affine::makeComposedFoldedAffineApply(
+              rewriter, loc, mulMap,
+              {knownProduct, sizeToOFR(reassocSizes[j])});
+        }
+      }
+
+      SmallVector<OpFoldResult> outputShape;
+      SmallVector<Value> inputSizes =
+          getTensorSizes(rewriter, loc, adaptor.getSelf());
+      for (int64_t i = 0; i < inputRank; ++i) {
+        if (i != dimInt) {
+          OpFoldResult inputDimSize =
+              (inputTensorSizes[i] != Torch::kUnknownSize)
+                  ? rewriter.getIndexAttr(inputTensorSizes[i])
+                  : OpFoldResult(inputSizes[i]);
+          outputShape.push_back(inputDimSize);
+          continue;
+        }
+
+        OpFoldResult inputDimSize =
+            (inputTensorSizes[dimInt] != Torch::kUnknownSize)
+                ? rewriter.getIndexAttr(inputTensorSizes[dimInt])
+                : OpFoldResult(inputSizes[dimInt]);
+        for (int64_t j = 0; j < numSizes; ++j) {
+          if (j == minusOneIdx) {
+            auto divMap =
+                AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext());
+            outputShape.push_back(affine::makeComposedFoldedAffineApply(
+                rewriter, loc, divMap, {inputDimSize, knownProduct}));
+          } else {
+            outputShape.push_back(sizeToOFR(reassocSizes[j]));
+          }
+        }
+      }
+
+      SmallVector<int64_t> resultShape =
+          decomposeMixedValues(outputShape).first;
+      auto resultType =
+          RankedTensorType::get(resultShape, expandTy.getElementType());
+      expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType,
+                                             adaptor.getSelf(), reassociations,
+                                             outputShape)
+                   .getResult();
+
+      if (resultType != expandTy) {
+        expand =
+            rewriter.create<tensor::CastOp>(loc, expandTy, expand).getResult();
+      }
     } else {

A review thread was attached to the branch above:

Reviewer: Is this path still needed?
Author: Yes, I think so. I tried without; it doesn't seem to be possible. Unless you had a specific patch in mind?
Reviewer: I don't see when it would need to create a …
Author: To be honest, I'm not sure either, outside of the explicit check, for which I've added a case for now. Otherwise I'm not sure whether the logic can be simplified. Sometime last week I tried combining the two, and it didn't work.

       reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
                                             reassocSizes);
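[Editorial note] The tensor.cast at the end of the expand_shape branch may be the "explicit check" case the (truncated) thread alludes to: shape inference can fold the result to a more static type than the type converter's expected result type, in which case a generalizing cast reconciles the two. An illustrative example, not taken from the patch:

// Hypothetical: inference produced a fully static shape, but the
// converted output type kept a dynamic dim.
%cast = tensor.cast %expanded : tensor<3x2x6xf32> to tensor<3x2x?xf32>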
@@ -745,6 +817,7 @@ class ConvertAtenUnflattenIntOp
                                               shapeValue)
                    .getResult();
     }
+
     rewriter.replaceOp(op, expand);
     return success();
   }
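[Editorial note] The heart of the patch is the inference arithmetic: every explicitly given size is folded into knownProduct, and a -1 entry is resolved as inputDimSize floordiv knownProduct through a composed affine apply. A hypothetical sketch of the IR this yields when a dynamic dim is unflattened with sizes [%s, -1] (all value names invented; %c0 and %c1 are assumed index constants, and %s is assumed already converted to index):

%d0 = tensor.dim %x, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %x, %c1 : tensor<?x?xf32>
// knownProduct is just %s here; the -1 slot becomes %d1 floordiv %s.
%inferred = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%d1, %s]
%e = tensor.expand_shape %x [[0], [1, 2]] output_shape [%d0, %s, %inferred]
    : tensor<?x?xf32> into tensor<?x?x?xf32>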
@@ -0,0 +1,74 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.unflatten.int$static
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
  return %1 : !torch.vtensor<[2,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
  %int-2 = torch.constant.int -2
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
  return %1 : !torch.vtensor<[2,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> {
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int-1 = torch.constant.int -1
  %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,2,6],f32>
  return %1 : !torch.vtensor<[3,2,6],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> {
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,3],f32>
  return %1 : !torch.vtensor<[?,2,3],f32>
}
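[Editorial note] The dynamic_input case is the one the patch improves: rather than deriving the dynamic output dim from the converted result type alone, the lowering now supplies the output shape directly to tensor.expand_shape. A rough sketch of the expected IR, with invented value names rather than literal FileCheck captures:

%c0 = arith.constant 0 : index
%d0 = tensor.dim %input, %c0 : tensor<?x6xf32>
%e = tensor.expand_shape %input [[0], [1, 2]] output_shape [%d0, 2, 3]
    : tensor<?x6xf32> into tensor<?x2x3xf32>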
// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.from_elements
// CHECK: tensor.reshape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %int1 = torch.constant.int 1
  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int
  %0 = torch.prim.ListConstruct %2, %2 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
  return %1 : !torch.vtensor<[?,?,?],f32>
}
A further review thread, on a build-target rename included in this PR:

Reviewer: I don't really understand this change. If the old script is broken (which would be very odd), this should definitely be made into a separate PR.
Author: I just noticed that running "cmake --build build --target check-torch-mlir-python" would ask "Did you mean check-torch_mlir-python?", so I made that change. But I will make this a separate PR.