[MLIR][Linalg] Fix insert_slice fusion with rank reduction #130961
```diff
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
```
```diff
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
```
```diff
@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
+/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// `from`.
+tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+                                   const llvm::SmallBitVector &dropDims) {
+  auto fromType = cast<ShapedType>(from.getType());
+  assert(fromType.getRank() == dropDims.size());
+  SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+  ReassociationIndices reassocIdxs;
+
+  bool foundKeptDim = false;
+  for (int dim = 0; dim < fromType.getRank(); dim++) {
+    if (!dropDims.test(dim)) {
+      if (foundKeptDim) {
+        reassocIdxsVec.push_back(reassocIdxs);
+        reassocIdxs.clear();
+      }
+      foundKeptDim = true;
+    }
+    reassocIdxs.push_back(dim);
+  }
+  if (!reassocIdxs.empty())
+    reassocIdxsVec.push_back(reassocIdxs);
+  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+}
+
 FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
```
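For intuition, the helper's grouping loop can be exercised outside MLIR. The sketch below is illustrative only (the name `computeReassociation` is hypothetical, and plain `std::vector` stands in for `SmallBitVector` and `ReassociationIndices`); it reproduces the same policy: each dropped dimension is folded into the reassociation group of an adjacent kept dimension, with leading dropped dimensions attaching to the first kept one.

```cpp
#include <cstdio>
#include <vector>

using Group = std::vector<int>;

// Same grouping policy as the `collapseTo` loop above: dropped dims ride
// along in the group of a neighbouring kept dim.
std::vector<Group> computeReassociation(const std::vector<bool> &dropDims) {
  std::vector<Group> groups;
  Group current;
  bool foundKeptDim = false;
  for (int dim = 0, rank = (int)dropDims.size(); dim < rank; ++dim) {
    if (!dropDims[dim]) {   // `dim` is kept
      if (foundKeptDim) {   // close the previous kept dim's group
        groups.push_back(current);
        current.clear();
      }
      foundKeptDim = true;
    }
    current.push_back(dim); // dropped dims accumulate into the open group
  }
  if (!current.empty())
    groups.push_back(current);
  return groups;
}

int main() {
  // Rank-4 tensor, e.g. tensor<1x8x1x16>, with unit dims 0 and 2 dropped.
  std::vector<bool> dropDims = {true, false, true, false};
  for (const Group &g : computeReassociation(dropDims)) {
    printf("[");
    for (int d : g)
      printf(" %d", d);
    printf(" ]\n"); // prints: [ 0 1 2 ] then [ 3 ]
  }
}
```

With that reassociation, `[[0,1,2],[3]]`, a `tensor.collapse_shape` turns `tensor<1x8x1x16>` into `tensor<8x16>`, since the collapsed groups multiply out to 1*8*1 = 8 and 16.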
```diff
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                << "\nNot fusable, not an extract_slice op: " << inputTensor);
     return failure();
   }
+  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
```
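`getDroppedDims` reports, per source dimension of the slice, whether a rank-reducing `tensor.extract_slice` elided that unit dimension from the result type. The following is only a loose model of those semantics, not MLIR's implementation: greedily mark static size-1 dimensions as dropped until the result rank is reached.

```cpp
#include <cstdio>
#include <vector>

// Loose model of the getDroppedDims semantics (assumed simplification):
// drop unit-sized dims, left to right, until the result rank is reached.
std::vector<bool> modelDroppedDims(const std::vector<long> &sliceSizes,
                                   int resultRank) {
  int toDrop = (int)sliceSizes.size() - resultRank;
  std::vector<bool> dropped(sliceSizes.size(), false);
  for (size_t i = 0; i < sliceSizes.size() && toDrop > 0; ++i) {
    if (sliceSizes[i] == 1) {
      dropped[i] = true;
      --toDrop;
    }
  }
  return dropped;
}

int main() {
  // A slice with sizes [1, 8, 1, 16] producing tensor<8x16>:
  // dims 0 and 2 are rank-reduced away.
  for (bool b : modelDroppedDims({1, 8, 1, 16}, 2))
    printf("%d ", (int)b); // prints: 1 0 1 0
  printf("\n");
}
```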
```diff
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Rank-reduction occurred as part of the extract_slice.
+  if (cast<ShapedType>(consumerType).getRank() !=
+      cast<ShapedType>(def.getType()).getRank())
+    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
```