Skip to content

Commit 9c38ad5

Browse files
committed
fix upon review
1 parent 3b91645 commit 9c38ad5

File tree

1 file changed

+6
-17
lines changed

1 file changed

+6
-17
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,32 +1239,21 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12391239
auto resultType = collapseOp.getResultType();
12401240

12411241
ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape();
1242-
1243-
SmallVector<OpFoldResult> collapsedShape;
1244-
for (int64_t dimSize : finalShape) {
1245-
if (dimSize == ShapedType::kDynamic) {
1246-
collapsedShape.push_back(OpFoldResult{});
1247-
} else {
1248-
collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
1249-
}
1250-
}
1251-
1242+
SmallVector<int64_t> collapsedShape(finalShape.begin(), finalShape.end());
12521243
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
12531244
OpFoldResult l = low[reInd[0]];
12541245
OpFoldResult h = high[reInd[0]];
1255-
12561246
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
1257-
collapsedShape[inDimIdx] =
1247+
auto mixedSize =
12581248
tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
1249+
auto dimSize = getConstantIntValue(mixedSize);
1250+
assert(dimSize.has_value() && "Expected static dimension");
1251+
collapsedShape[inDimIdx] = *dimSize;
12591252
}
12601253
}
12611254

1262-
SmallVector<int64_t> staticCollapsedShape;
1263-
std::tie(staticCollapsedShape, std::ignore) =
1264-
decomposeMixedValues(collapsedShape);
1265-
12661255
auto newCollapseType = RankedTensorType::get(
1267-
staticCollapsedShape, padOp.getSource().getType().getElementType());
1256+
collapsedShape, padOp.getSource().getType().getElementType());
12681257
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
12691258
loc, newCollapseType, padOp.getSource(), reassociations);
12701259

0 commit comments

Comments
 (0)