|
9 | 9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | 10 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
11 | 11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 12 | +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
12 | 13 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
13 | 14 | #include "mlir/IR/AffineMap.h" |
14 | 15 | #include "mlir/IR/Builders.h" |
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { |
1124 | 1125 | } |
1125 | 1126 | } // else dim.getIndex is a block argument to reshape->getBlock and |
1126 | 1127 | // dominates reshape |
1127 | | - } // Check condition 2 |
| 1128 | + } // Check condition 2 |
1128 | 1129 | else if (dim->getBlock() != reshape->getBlock() && |
1129 | 1130 | !dim.getIndex().getParentRegion()->isProperAncestor( |
1130 | 1131 | reshape->getParentRegion())) { |
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType( |
2525 | 2526 | srcType.getMemorySpace()); |
2526 | 2527 | } |
2527 | 2528 |
|
| 2529 | +MemRefType |
| 2530 | +CollapseShapeOp::inferCollapsedType(MemRefType type, |
| 2531 | + ArrayRef<AffineMap> reassociation) { |
| 2532 | + auto shape = type.getShape(); |
| 2533 | + SmallVector<int64_t, 4> newShape; |
| 2534 | + assert(isReassociationValid(reassociation) && "invalid reassociation"); |
| 2535 | + unsigned currentDim = 0; |
| 2536 | + for (AffineMap m : reassociation) { |
| 2537 | + unsigned dim = m.getNumResults(); |
| 2538 | + auto band = shape.slice(currentDim, dim); |
| 2539 | + int64_t size = 1; |
| 2540 | + if (llvm::is_contained(band, ShapedType::kDynamic)) |
| 2541 | + size = ShapedType::kDynamic; |
| 2542 | + else |
| 2543 | + for (unsigned d = 0; d < dim; ++d) |
| 2544 | + size *= shape[currentDim + d]; |
| 2545 | + newShape.push_back(size); |
| 2546 | + currentDim += dim; |
| 2547 | + } |
| 2548 | + return MemRefType::get(newShape, type.getElementType()); |
| 2549 | +} |
| 2550 | + |
| 2551 | +MemRefType CollapseShapeOp::inferCollapsedType( |
| 2552 | + MemRefType type, SmallVector<ReassociationIndices> reassociation) { |
| 2553 | + return inferCollapsedType( |
| 2554 | + type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( |
| 2555 | + type.getContext(), reassociation))); |
| 2556 | +} |
| 2557 | + |
2528 | 2558 | void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, |
2529 | 2559 | ArrayRef<ReassociationIndices> reassociation, |
2530 | 2560 | ArrayRef<NamedAttribute> attrs) { |
|
0 commit comments