|
14 | 14 | #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H |
15 | 15 | #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H |
16 | 16 |
|
| 17 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
17 | 18 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
18 | 19 | #include "mlir/IR/OpImplementation.h" |
19 | 20 | #include "mlir/IR/PatternMatch.h" |
@@ -305,8 +306,42 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> { |
305 | 306 | rewriter.replaceOpWithNewOp<CollapseOpTy>( |
306 | 307 | collapseOp, resultType, expandOp.getSrc(), composedReassociation); |
307 | 308 | } else if (srcRank < resultRank) { |
| 309 | + // Compute the dynamic output shape for the new expand_shape op. |
| 310 | + Location loc = collapseOp.getLoc(); |
| 311 | + SmallVector<OpFoldResult> origOutputShape = |
| 312 | + expandOp.getMixedOutputShape(); |
| 313 | + SmallVector<OpFoldResult> newOutputShape; |
| 314 | + for (auto indices : collapseOp.getReassociationIndices()) { |
| 315 | + int64_t numStaticElems = 1; |
| 316 | + SmallVector<Value> dynamicSizes; |
| 317 | + for (auto idx : indices) { |
| 318 | + OpFoldResult size = origOutputShape[idx]; |
| 319 | + if (auto maybeCst = getConstantIntValue(size)) { |
| 320 | + numStaticElems *= maybeCst.value(); |
| 321 | + continue; |
| 322 | + } |
| 323 | + dynamicSizes.push_back(cast<Value>(size)); |
| 324 | + } |
| 325 | + if (dynamicSizes.empty()) { |
| 326 | + newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems)); |
| 327 | + continue; |
| 328 | + } |
| 329 | + |
| 330 | + // There is at least one dynamic size, so we can intialize `result` to |
| 331 | + // the first dynamic size. |
| 332 | + Value result = dynamicSizes[0]; |
| 333 | + for (auto v : llvm::drop_begin(dynamicSizes)) |
| 334 | + result = rewriter.create<arith::MulIOp>(loc, result, v); |
| 335 | + if (numStaticElems != 1) { |
| 336 | + result = rewriter.create<arith::MulIOp>( |
| 337 | + loc, result, |
| 338 | + rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems)); |
| 339 | + } |
| 340 | + newOutputShape.push_back(result); |
| 341 | + } |
308 | 342 | rewriter.replaceOpWithNewOp<ExpandOpTy>( |
309 | | - collapseOp, resultType, expandOp.getSrc(), composedReassociation); |
| 343 | + collapseOp, resultType, expandOp.getSrc(), composedReassociation, |
| 344 | + newOutputShape); |
310 | 345 | } else { |
311 | 346 | // Collapses/expansions that do not change the rank are not allowed. Use |
312 | 347 | // a cast instead. |
|
0 commit comments