Skip to content

Commit 17a2447

Browse files
committed
fix upon review
1 parent d8ca036 commit 17a2447

File tree

1 file changed

+5
-26
lines changed

1 file changed

+5
-26
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
#include "mlir/Support/LLVM.h"
2727
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2828
#include "mlir/Transforms/RegionUtils.h"
29-
#include "llvm/ADT/STLExtras.h"
30-
#include "llvm/Support/LogicalResult.h"
3129
#include <optional>
3230
#include <utility>
3331

@@ -1169,19 +1167,10 @@ struct FoldReshapeWithProducerPadOpByExpansion
11691167
OpFoldResult h = high[inDimIdx];
11701168

11711169
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
1172-
auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
1173-
int64_t originalSize = srcType.getDimSize(inDimIdx);
1174-
1175-
OpFoldResult originalSizeOFR;
1176-
if (originalSize == ShapedType::kDynamic) {
1177-
Value orgSizeVal =
1178-
rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
1179-
originalSizeOFR = orgSizeVal;
1180-
} else {
1181-
originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
1182-
}
11831170
assert(reInd.size() == 1 && "expected single dimension");
1184-
expandedShape[reInd[0]] = originalSizeOFR;
1171+
expandedShape[reInd[0]] =
1172+
tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
1173+
;
11851174
}
11861175
}
11871176

@@ -1288,18 +1277,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12881277
OpFoldResult h = high[reInd[0]];
12891278

12901279
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
1291-
auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
1292-
int64_t originalSize = srcType.getDimSize(reInd[0]);
1293-
1294-
OpFoldResult originalSizeOFR;
1295-
if (originalSize == ShapedType::kDynamic) {
1296-
Value orgSizeVal =
1297-
rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
1298-
originalSizeOFR = orgSizeVal;
1299-
} else {
1300-
originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
1301-
}
1302-
collapsedShape[inDimIdx] = originalSizeOFR;
1280+
collapsedShape[inDimIdx] =
1281+
tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
13031282
}
13041283
}
13051284

0 commit comments

Comments
 (0)