-
Notifications
You must be signed in to change notification settings - Fork 639
fix: infer output shape directly in ConvertAtenUnflattenIntOp #4325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
71529af
a9c0f41
514c0c9
98a90a3
0c19f99
96088b8
14e1c19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" | ||
|
||
#include "PopulatePatterns.h" | ||
#include "mlir/Dialect/Affine/Utils.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Complex/IR/Complex.h" | ||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | ||
|
@@ -692,15 +693,15 @@ class ConvertAtenUnflattenIntOp | |
if (outputSizes[i] == Torch::kUnknownSize) | ||
numDynamicReassocDims++; | ||
} | ||
|
||
SmallVector<Value> reassocSizes; | ||
if (!getListConstructElements(op.getSizes(), reassocSizes) && | ||
numDynamicReassocDims > 1) | ||
return rewriter.notifyMatchFailure( | ||
op, "Must be able to either infer expansion dims, or retrieve them " | ||
"from list construct"); | ||
|
||
auto expandTy = getTypeConverter()->convertType(outputTensorType); | ||
RankedTensorType expandTy = cast<RankedTensorType>( | ||
getTypeConverter()->convertType(outputTensorType)); | ||
Value expand; | ||
// When there are less than two dynamic reassociation dims, this will lower | ||
// to tensor.expand_shape. Otherwise, this lowers to tensor.reshape. | ||
|
@@ -717,10 +718,102 @@ class ConvertAtenUnflattenIntOp | |
for (int i = dimInt + numSizes; i < outputRank; ++i) | ||
reassociations[i - numSizes + 1].push_back(i); | ||
} | ||
expand = rewriter | ||
.create<tensor::ExpandShapeOp>( | ||
loc, expandTy, adaptor.getSelf(), reassociations) | ||
|
||
// Is there a function that already does this somewhere? | ||
auto sizeToOFR = [&](Value sizeVal) -> OpFoldResult { | ||
int64_t constantSize; | ||
if (matchPattern(sizeVal, m_TorchConstantInt(&constantSize))) { | ||
return rewriter.getIndexAttr(constantSize); | ||
} | ||
SmallVector<Value> singleSizeVec = {sizeVal}; | ||
Value converted = castIntToIndex( | ||
rewriter, loc, | ||
getTypeConvertedValues(rewriter, loc, getTypeConverter(), | ||
singleSizeVec)[0]); | ||
return OpFoldResult(converted); | ||
}; | ||
|
||
int64_t minusOneIdx = -1; | ||
OpFoldResult knownProduct = rewriter.getIndexAttr(1); | ||
AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext()); | ||
AffineExpr s1 = getAffineSymbolExpr(1, rewriter.getContext()); | ||
auto mulMap = AffineMap::get(0, 2, s0 * s1, rewriter.getContext()); | ||
|
||
for (int64_t j = 0, e = reassocSizes.size(); j < e; ++j) { | ||
int64_t constantSize; | ||
// mlir::Value to int comparison... | ||
if (matchPattern(reassocSizes[j], m_TorchConstantInt(&constantSize)) && | ||
constantSize == -1) { | ||
minusOneIdx = j; | ||
} else { | ||
knownProduct = affine::makeComposedFoldedAffineApply( | ||
rewriter, loc, mulMap, | ||
{knownProduct, sizeToOFR(reassocSizes[j])}); | ||
} | ||
} | ||
|
||
SmallVector<OpFoldResult> outputShape; | ||
SmallVector<Value> inputSizes = | ||
getTensorSizes(rewriter, loc, adaptor.getSelf()); | ||
for (int64_t i = 0; i < inputRank; ++i) { | ||
if (i == dimInt) { | ||
OpFoldResult inputDimSize = | ||
(inputTensorType.getSizes()[dimInt] != Torch::kUnknownSize) | ||
raayandhar marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
? rewriter.getIndexAttr(inputTensorType.getSizes()[dimInt]) | ||
: OpFoldResult(inputSizes[dimInt]); | ||
for (int64_t j = 0; j < numSizes; ++j) { | ||
if (j == minusOneIdx) { | ||
auto divMap = | ||
AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext()); | ||
outputShape.push_back(affine::makeComposedFoldedAffineApply( | ||
rewriter, loc, divMap, {inputDimSize, knownProduct})); | ||
} else { | ||
outputShape.push_back(sizeToOFR(reassocSizes[j])); | ||
} | ||
} | ||
} else { | ||
OpFoldResult inputDimSize = | ||
(inputTensorType.getSizes()[i] != Torch::kUnknownSize) | ||
? rewriter.getIndexAttr(inputTensorType.getSizes()[i]) | ||
: OpFoldResult(inputSizes[i]); | ||
outputShape.push_back(inputDimSize); | ||
} | ||
} | ||
|
||
// Originally I was doing: | ||
// expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, | ||
// adaptor.getSelf(), reassociations, outputShape).getResult(); But with | ||
// that I was running into: error: 'tensor.expand_shape' op expected | ||
// dimension 0 of collapsed type to be dynamic since one or more of the | ||
// corresponding dimensions in the expanded type is dynamic %4491 = | ||
// torch.aten.as_strided %4488, %4489, %4490, %int0_462 : | ||
// !torch.vtensor<[2,4096,5120],f16>, !torch.list<int>, !torch.list<int>, | ||
// !torch.int -> !torch.vtensor<[2,4096,2560],f16> | ||
// /home/rdhar/expand-shape-bug/iree/iree-model-benchmark/sdxl/int8-model/base_ir/stable_diffusion_xl_base_1_0_scheduled_unet_bs1_64_1024x1024_i8.mlir:13071:13: | ||
// note: see current operation: %17734 = "tensor.expand_shape"(%17730) | ||
// <{reassociation = [[0, 1, 2]], static_output_shape = array<i64: 2, 1, | ||
// 1>}> : (tensor<2xi64>) -> tensor<?x1x1xi64> So there is this really | ||
// ugly code to handle the types... but it kind of defeats all the code | ||
// above. | ||
SmallVector<int64_t> resultShape; | ||
for (OpFoldResult ofr : outputShape) { | ||
if (auto attr = ofr.dyn_cast<Attribute>()) { | ||
resultShape.push_back(cast<IntegerAttr>(attr).getInt()); | ||
} else { | ||
resultShape.push_back(ShapedType::kDynamic); | ||
} | ||
} | ||
raayandhar marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
auto resultType = | ||
RankedTensorType::get(resultShape, expandTy.getElementType()); | ||
expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType, | ||
adaptor.getSelf(), reassociations, | ||
outputShape) | ||
.getResult(); | ||
|
||
if (resultType != expandTy) { | ||
expand = | ||
rewriter.create<tensor::CastOp>(loc, expandTy, expand).getResult(); | ||
} | ||
} else { | ||
Reviewer: Is this path still needed?
Author: Yes, I think so. I tried without it — it doesn't seem to be possible. Unless you had a specific patch in mind?
Reviewer: I don't see when it would need to create a [tensor.reshape].
Author: To be honest, I'm not sure either, outside of the explicit check, for which I've added a case for now. Otherwise I'm not sure the logic can be simplified; sometime last week I tested combining the two paths and it didn't work. |
||
reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), | ||
reassocSizes); | ||
|
@@ -745,6 +838,7 @@ class ConvertAtenUnflattenIntOp | |
shapeValue) | ||
.getResult(); | ||
} | ||
|
||
rewriter.replaceOp(op, expand); | ||
return success(); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.