|
26 | 26 | #include "mlir/Support/LLVM.h" |
27 | 27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
28 | 28 | #include "mlir/Transforms/RegionUtils.h" |
29 | | -#include "llvm/ADT/STLExtras.h" |
30 | | -#include "llvm/Support/LogicalResult.h" |
31 | 29 | #include <optional> |
32 | 30 | #include <utility> |
33 | 31 |
|
@@ -1169,19 +1167,10 @@ struct FoldReshapeWithProducerPadOpByExpansion |
1169 | 1167 | OpFoldResult h = high[inDimIdx]; |
1170 | 1168 |
|
1171 | 1169 | if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { |
1172 | | - auto srcType = cast<RankedTensorType>(padOp.getSource().getType()); |
1173 | | - int64_t originalSize = srcType.getDimSize(inDimIdx); |
1174 | | - |
1175 | | - OpFoldResult originalSizeOFR; |
1176 | | - if (originalSize == ShapedType::kDynamic) { |
1177 | | - Value orgSizeVal = |
1178 | | - rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx); |
1179 | | - originalSizeOFR = orgSizeVal; |
1180 | | - } else { |
1181 | | - originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); |
1182 | | - } |
1183 | 1170 | assert(reInd.size() == 1 && "expected single dimension"); |
1184 | | - expandedShape[reInd[0]] = originalSizeOFR; |
| 1171 | + expandedShape[reInd[0]] = |
| 1172 | + tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx); |
| 1173 | + ; |
1185 | 1174 | } |
1186 | 1175 | } |
1187 | 1176 |
|
@@ -1288,18 +1277,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing |
1288 | 1277 | OpFoldResult h = high[reInd[0]]; |
1289 | 1278 |
|
1290 | 1279 | if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { |
1291 | | - auto srcType = cast<RankedTensorType>(padOp.getSource().getType()); |
1292 | | - int64_t originalSize = srcType.getDimSize(reInd[0]); |
1293 | | - |
1294 | | - OpFoldResult originalSizeOFR; |
1295 | | - if (originalSize == ShapedType::kDynamic) { |
1296 | | - Value orgSizeVal = |
1297 | | - rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]); |
1298 | | - originalSizeOFR = orgSizeVal; |
1299 | | - } else { |
1300 | | - originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); |
1301 | | - } |
1302 | | - collapsedShape[inDimIdx] = originalSizeOFR; |
| 1280 | + collapsedShape[inDimIdx] = |
| 1281 | + tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]); |
1303 | 1282 | } |
1304 | 1283 | } |
1305 | 1284 |
|
|
0 commit comments