[MLIR] Add shape propagation through tensor.pad #136681
```diff
@@ -26,6 +26,8 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 #include <utility>
```
```diff
@@ -1100,6 +1102,267 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
```
```cpp
// Returns true if `value` holds a constant integer equal to zero.
bool isZero(OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value)) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return intAttr.getInt() == 0;
  }
  if (auto val = dyn_cast<Value>(value)) {
    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
        return attr.getInt() == 0;
    }
  }
  return false;
}
```
Suggested change:

```cpp
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
  return getConstantIntValue(ofr) == value;
}
```
Done
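For reference, MLIR already provides `isConstantIntValue` with exactly this behavior (declared in `mlir/Dialect/Utils/StaticValueUtils.h`). A minimal usage sketch, with `l` and `h` standing for one dimension's low/high pad values as in the loop discussed below:

```cpp
// Skip dimensions that carry no padding at all, reusing the existing
// helper instead of a hand-rolled isZero().
if (isConstantIntValue(l, 0) && isConstantIntValue(h, 0))
  continue;
```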
To me, "folding" would mean that tensor.expand_shape disappears (i.e. is folded away), but that's not what is happening here, is it? This is merely "bubbling up".
Please update the description accordingly and add some example IR before and after. As an example: https://github.com/banach-space/llvm-project/blob/7d35eb58959c0ab398a9739f38bfb9754c5ba5e5/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp#L305-L317
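For illustration, a hypothetical before/after pair in the style of the linked ReshapePatterns.cpp comment (the shapes, reassociation, and padding here are invented):

```cpp
// Bubbling tensor.pad up through a producer tensor.expand_shape, assuming
// the padded dimension maps to a single source dimension:
//
//   %e = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 16, 8]
//       : tensor<32x8xf32> into tensor<2x16x8xf32>
//   %p = tensor.pad %e low[0, 0, 1] high[0, 0, 1] { ... }
//       : tensor<2x16x8xf32> to tensor<2x16x10xf32>
//
// becomes
//
//   %p = tensor.pad %src low[0, 1] high[0, 1] { ... }
//       : tensor<32x8xf32> to tensor<32x10xf32>
//   %e = tensor.expand_shape %p [[0, 1], [2]] output_shape [2, 16, 10]
//       : tensor<32x10xf32> into tensor<2x16x10xf32>
```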
nit: you can use llvm::zip_equal
Suggested change:

```diff
-for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-  OpFoldResult l = low[idx];
-  OpFoldResult h = high[idx];
+for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
```
Done
nit: Use rewriter.notifyMatchFailure() like above?
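A minimal sketch of that shape (the condition and message text here are illustrative):

```cpp
// Record why the rewrite did not apply instead of returning a bare
// failure(); the reason surfaces in pattern-application debug output.
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0))
  return rewriter.notifyMatchFailure(
      padOp, "cannot propagate pad through this reshape dimension");
```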
Done
nit: you can use append
Suggested change:

```diff
-for (size_t i = 0; i < reInd.size(); ++i) {
-  newLow.push_back(low[idx]);
-  newHigh.push_back(high[idx]);
-}
+newLow.append(reInd.size(), low[idx]);
+newHigh.append(reInd.size(), high[idx]);
```
Done
nit: expandOp.getResultType() will return a RankedTensorType, so you can do:
Suggested change:

```diff
-auto finalType = cast<RankedTensorType>(expandOp.getType());
-ArrayRef<int64_t> finalShape = finalType.getShape();
+ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();
```
I think what you're looking for is expandOp.getMixedOutputShape().
Suggested change:

```diff
-SmallVector<OpFoldResult> expandedShape;
-for (int64_t dimSize : finalShape) {
-  if (dimSize == ShapedType::kDynamic) {
-    expandedShape.push_back(OpFoldResult{});
-  } else {
-    expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
-  }
-}
+SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
```
```
/Users/ita/src/llvm-project/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp:1257:59: error: no member named 'getMixedOutputShape' in 'mlir::tensor::CollapseShapeOp'
 1257 |   SmallVector<OpFoldResult> collapsedShape = collapseOp.getMixedOutputShape();
```
I'll make a separate PR for this support; I think CollapseShapeOp may need getMixedOutputShape too.
CollapseShape doesn't need to carry the output shape, so it won't have it, but you can use it for the expand_shape pattern.
Max191 marked this conversation as resolved.
nit: save a level of nesting with early continue, i.e.:
```cpp
if (isZero(l) && isZero(h)) {
  continue;
}
```
I'd prefer not to apply this change; with your other suggestions the code is already simpler, and I think the logic reads more clearly as is.
I think what you're looking for is tensor::getMixedSize:
Suggested change:

```diff
-auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
-int64_t originalSize = srcType.getDimSize(inDimIdx);
-OpFoldResult originalSizeOFR;
-if (originalSize == ShapedType::kDynamic) {
-  Value orgSizeVal =
-      rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
-  originalSizeOFR = orgSizeVal;
-} else {
-  originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
-}
+OpFoldResult originalSize =
+    tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
```
Done
This seems to still be unchanged too?
When I apply this I see a test failure; I'm still working on applying it.
Sorry, that was my mistake. Solved!
I know that `reInd` should have size 1 from the previous matching, but I think the logic is clearer if you add an assert here that `reInd.size() == 1`, and then just do `expandedShape[reInd[0]] = originalSizeOFR;`.
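In code, the suggested form would be roughly (identifiers as used in the surrounding discussion):

```cpp
// The earlier match only allows padding on dims that map to exactly one
// expanded dim, so the reassociation group must be a singleton here.
assert(reInd.size() == 1 && "expected singleton reassociation group");
expandedShape[reInd[0]] = originalSizeOFR;
```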
Done
I think this was necessary because some of the expandedShape were null right? I'm pretty sure this shouldn't be necessary if you use getMixedOutputShape as per my above comment.
Did I understand your comment correctly that this can be reduced:

```cpp
for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
  if (dimSize == ShapedType::kDynamic &&
      !isa<Value>(expandedShape[outDimIdx]) &&
      !isa<Attribute>(expandedShape[outDimIdx])) {
    expandedShape[outDimIdx] =
        tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
  }
}
```
You can use decomposeMixedValues here:
Suggested change:

```diff
 SmallVector<int64_t> staticExpandedShape;
-for (OpFoldResult dim : expandedShape) {
-  if (auto attr = dyn_cast<Attribute>(dim)) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-      staticExpandedShape.push_back(intAttr.getInt());
-    } else {
-      staticExpandedShape.push_back(ShapedType::kDynamic);
-    }
-  } else {
-    staticExpandedShape.push_back(ShapedType::kDynamic);
-  }
-}
+std::tie(staticExpandedShape, std::ignore) = decomposeMixedValues(expandedShape);
```
Done
I think you also want to pass the mixed output shape here to use the correct builder:
Suggested change:

```diff
 auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
     loc,
     RankedTensorType::get(staticExpandedShape,
                           padOp.getSource().getType().getElementType()),
-    padOp.getSource(), reassociations);
+    padOp.getSource(), reassociations, expandedShape);
```
Done
nit: use rewriter.replaceOpWithNewOp<tensor::PadOp>?
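A sketch of that call, reusing the operands built earlier in the pattern (`newPadType` is an illustrative name here, and `getConstantPaddingValue()` assumes the pad value is a constant yield):

```cpp
// Replace the original pad in one step; replaceOpWithNewOp forwards these
// arguments to the tensor::PadOp builder and erases padOp.
rewriter.replaceOpWithNewOp<tensor::PadOp>(
    padOp, newPadType, newExpandOp.getResult(), newLow, newHigh,
    padOp.getConstantPaddingValue(), padOp.getNofold());
```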
Done
nit: combine this loop with the loop above.
You shouldn't need the mixed sizes for the collapsed shape. It is only used for getting the new type, so you can just collect static sizes instead (SmallVector<int64_t>).
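A rough sketch of that idea, assuming `paddedStaticShape` holds the (already padded) static sizes of the source; all names here are illustrative:

```cpp
// Only static sizes are needed to build the collapsed result type; any
// group containing a dynamic dim collapses to kDynamic.
SmallVector<int64_t> collapsedStaticShape;
for (const ReassociationIndices &reInd : reassociations) {
  int64_t size = 1;
  for (int64_t dim : reInd) {
    if (paddedStaticShape[dim] == ShapedType::kDynamic) {
      size = ShapedType::kDynamic;
      break;
    }
    size *= paddedStaticShape[dim];
  }
  collapsedStaticShape.push_back(size);
}
```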
nit: I don't think these includes are needed
Done
Seems like these are still here?
Sorry, done.