Skip to content

Commit 9ee8e08

Browse files
committed
fix upon review
1 parent 17a2447 commit 9ee8e08

File tree

1 file changed

+3
-13
lines changed

1 file changed

+3
-13
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,17 +1150,8 @@ struct FoldReshapeWithProducerPadOpByExpansion
11501150
}
11511151

11521152
Location loc = expandOp.getLoc();
1153-
auto finalType = cast<RankedTensorType>(expandOp.getType());
1154-
ArrayRef<int64_t> finalShape = finalType.getShape();
1155-
1156-
SmallVector<OpFoldResult> expandedShape;
1157-
for (int64_t dimSize : finalShape) {
1158-
if (dimSize == ShapedType::kDynamic) {
1159-
expandedShape.push_back(OpFoldResult{});
1160-
} else {
1161-
expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
1162-
}
1163-
}
1153+
ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();
1154+
SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
11641155

11651156
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
11661157
OpFoldResult l = low[inDimIdx];
@@ -1260,8 +1251,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12601251
Location loc = collapseOp.getLoc();
12611252
auto resultType = collapseOp.getResultType();
12621253

1263-
auto finalType = cast<RankedTensorType>(collapseOp.getType());
1264-
ArrayRef<int64_t> finalShape = finalType.getShape();
1254+
ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape();
12651255

12661256
SmallVector<OpFoldResult> collapsedShape;
12671257
for (int64_t dimSize : finalShape) {

0 commit comments

Comments
 (0)