Skip to content

Commit d484683

Browse files
Max191 authored and keshavvinayak01 committed
[DispatchCreation] Fuse reshape op chains along with set_encoding ops (iree-org#21365)
There can be reshape ops between set_encoding ops and their producer dispatch region ops when we fuse encoding ops into dispatches, so this PR pulls any chain of reshape ops into the producer dispatch along with the set_encoding op. This enables more producer fusions for set_encoding in the data tiling fusion path.

---------

Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: keshavvinayak01 <[email protected]>
1 parent cc0e843 commit d484683

File tree

5 files changed

+105
-61
lines changed

5 files changed

+105
-61
lines changed

compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
99
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1010
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
11+
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1112
#include "iree/compiler/DispatchCreation/Passes.h"
1213
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1314
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -83,16 +84,18 @@ struct FuseEncodingOpsIntoDispatchRegionsPass
8384

8485
for (IREE::Encoding::SetEncodingOp encodingOp : encodingOps) {
8586
OpOperand &operand = encodingOp.getSourceMutable();
86-
auto producerDispatch =
87-
operand.get().getDefiningOp<IREE::Flow::DispatchRegionOp>();
87+
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
88+
producerChain = getProducerDispatchValueAndOpChain(operand.get());
8889
// Nothing to fuse with, so wrap the `encodingOp` in its own dispatch.
89-
if (!producerDispatch) {
90+
if (!producerChain) {
9091
continue;
9192
}
9293

9394
// Find producer operation inside of the dispatch region to determine if
9495
// fusion is possible.
95-
auto result = cast<OpResult>(operand.get());
96+
OpResult result = producerChain->first;
97+
auto producerDispatch =
98+
result.getDefiningOp<IREE::Flow::DispatchRegionOp>();
9699
auto dispatchReturnOp = cast<IREE::Flow::ReturnOp>(
97100
producerDispatch.getBody().front().getTerminator());
98101
auto producerInRegion = dyn_cast<OpResult>(
@@ -107,10 +110,18 @@ struct FuseEncodingOpsIntoDispatchRegionsPass
107110
!isFusableWithSetEncoding(producerInRegion.getOwner())) {
108111
continue;
109112
}
110-
// Fuse the `encodingOp` into the producer dispatch region.
111-
if (failed(moveFollowingOpIntoDispatchRegion(rewriter, encodingOp,
112-
producerDispatch))) {
113-
return signalPassFailure();
113+
// Fuse the `encodingOp` and the producer chain into the dispatch.
114+
SmallVector<Operation *> dispatchConsumers(
115+
llvm::reverse(producerChain->second));
116+
dispatchConsumers.push_back(encodingOp);
117+
for (Operation *consumer : dispatchConsumers) {
118+
FailureOr<IREE::Flow::DispatchRegionOp> fusedDispatch =
119+
moveFollowingOpIntoDispatchRegion(rewriter, consumer,
120+
producerDispatch);
121+
if (failed(fusedDispatch)) {
122+
return signalPassFailure();
123+
}
124+
producerDispatch = fusedDispatch.value();
114125
}
115126
}
116127

compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,12 @@
99

1010
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
1111
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
12+
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1213
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1314
#include "mlir/Analysis/SliceAnalysis.h"
1415
#include "mlir/Analysis/TopologicalSortUtils.h"
1516
#include "mlir/Dialect/Linalg/IR/Linalg.h"
17+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1618
#include "mlir/Transforms/RegionUtils.h"
1719

1820
namespace mlir::iree_compiler::DispatchCreation {
@@ -177,4 +179,53 @@ LogicalResult moveOperandDefs(RewriterBase &rewriter,
177179
return success();
178180
}
179181

182+
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
183+
getProducerDispatchValueAndOpChain(Value operand) {
184+
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
185+
if (!operandType || operandType.getRank() == 0) {
186+
return std::nullopt;
187+
}
188+
189+
SmallVector<Operation *> opChain;
190+
auto producerValue = dyn_cast<OpResult>(operand);
191+
while (producerValue &&
192+
!isa<IREE::Flow::DispatchRegionOp>(producerValue.getOwner())) {
193+
if (!llvm::hasSingleElement(producerValue.getUses())) {
194+
return std::nullopt;
195+
}
196+
197+
// If it is an operation that we want to look past, add it to the chain
198+
// and update the `producerValue`.
199+
Operation *currOperation = producerValue.getOwner();
200+
if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(currOperation)) {
201+
opChain.push_back(currOperation);
202+
producerValue = dyn_cast<OpResult>(currOperation->getOperand(0));
203+
continue;
204+
}
205+
206+
// Conservative, bail out.
207+
return std::nullopt;
208+
}
209+
210+
if (!producerValue) {
211+
return std::nullopt;
212+
}
213+
214+
auto producerDispatch =
215+
dyn_cast<IREE::Flow::DispatchRegionOp>(producerValue.getOwner());
216+
// TODO(MaheshRavishankar): Multi-result producer dispatches can be supported.
217+
// Will require to move the consumer dispatch immediately after the producer
218+
// instead of what is done below and move other operands of the consumer
219+
// dispatch before the producer dispatch.
220+
if (!producerDispatch ||
221+
!llvm::hasSingleElement(producerDispatch.getBody()) ||
222+
producerDispatch->getNumResults() != 1) {
223+
return std::nullopt;
224+
}
225+
if (!llvm::hasSingleElement(producerValue.getUses())) {
226+
return std::nullopt;
227+
}
228+
return std::make_pair(producerValue, opChain);
229+
}
230+
180231
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/FusionUtils.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,4 +34,11 @@ LogicalResult moveOperandDefs(RewriterBase &rewriter,
3434
DominanceInfo &dominanceInfo,
3535
ArrayRef<Operation *> ignoreOperations = {});
3636

37+
/// Returns the closest producer dispatch region op result and the chain of
38+
/// operations being looked past during the traversal to find the producer
39+
/// dispatch. Returns std::nullopt if the dispatch or any ops in the chain have
40+
/// multiple uses.
41+
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
42+
getProducerDispatchValueAndOpChain(Value operand);
43+
3744
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp

Lines changed: 1 addition & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1717
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
1818
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
19+
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1920
#include "iree/compiler/DispatchCreation/Passes.h"
2021
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2122
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -310,59 +311,6 @@ struct FoldFillWithSetEncoding final
310311
//===---------------------------------------------------------------------===//
311312
// Set padding encodings
312313
//===---------------------------------------------------------------------===//
313-
314-
// Utility to return the producer dispatch region op result and the chain of
315-
// operations being looked past during the traversal to find the producer
316-
// dispatch.
317-
static std::optional<std::pair<OpResult, SmallVector<Operation *>>>
318-
getProducerDispatchValueAndOpChain(Value operand) {
319-
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
320-
if (!operandType || operandType.getRank() == 0) {
321-
return std::nullopt;
322-
}
323-
324-
SmallVector<Operation *> opChain;
325-
auto producerValue = dyn_cast<OpResult>(operand);
326-
while (producerValue &&
327-
!isa<IREE::Flow::DispatchRegionOp>(producerValue.getOwner())) {
328-
if (!llvm::hasSingleElement(producerValue.getUses())) {
329-
return std::nullopt;
330-
}
331-
332-
// If it is an operation that we want to look past, add it to the chain
333-
// and update the `producerValue`.
334-
Operation *currOperation = producerValue.getOwner();
335-
if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(currOperation)) {
336-
opChain.push_back(currOperation);
337-
producerValue = dyn_cast<OpResult>(currOperation->getOperand(0));
338-
continue;
339-
}
340-
341-
// Conservative, bail out.
342-
return std::nullopt;
343-
}
344-
345-
if (!producerValue) {
346-
return std::nullopt;
347-
}
348-
349-
auto producerDispatch =
350-
dyn_cast<IREE::Flow::DispatchRegionOp>(producerValue.getOwner());
351-
// TODO(MaheshRavishankar): Multi-result producer dispatches can be supported.
352-
// Will require to move the consumer dispatch immediately after the producer
353-
// instead of what is done below and move other operands of the consumer
354-
// dispatch before the producer dispatch.
355-
if (!producerDispatch ||
356-
!llvm::hasSingleElement(producerDispatch.getBody()) ||
357-
producerDispatch->getNumResults() != 1) {
358-
return std::nullopt;
359-
}
360-
if (!llvm::hasSingleElement(producerValue.getUses())) {
361-
return std::nullopt;
362-
}
363-
return std::make_pair(producerValue, opChain);
364-
}
365-
366314
struct PaddedValue {
367315
Value paddedValue;
368316
SmallVector<Value> dynamicDims;

compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -187,3 +187,30 @@ util.func public @multi_encoding_fusion_dynamic(%arg0: tensor<?x?x?xf32>, %d0: i
187187
// CHECK: flow.return %[[SET_ENCODING]] :
188188
// CHECK: }
189189
// CHECK: util.return %[[DISPATCH]], %[[DISPATCH]]
190+
191+
// -----
192+
193+
#encoding = #iree_encoding.testing<>
194+
util.func public @reshape_fusion(%arg0: tensor<32x32xf32>) -> tensor<16x64xf32, #encoding> {
195+
%cst = arith.constant 0.000000e+00 : f32
196+
%0 = tensor.empty() : tensor<32x32xf32>
197+
%1 = flow.dispatch.region -> (tensor<32x32xf32>) {
198+
%3 = linalg.add ins(%arg0, %arg0 : tensor<32x32xf32>, tensor<32x32xf32>)
199+
outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
200+
flow.return %3 : tensor<32x32xf32>
201+
}
202+
%collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<32x32xf32> into tensor<1024xf32>
203+
%expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [16, 64] : tensor<1024xf32> into tensor<16x64xf32>
204+
%2 = iree_encoding.set_encoding %expanded : tensor<16x64xf32> -> tensor<16x64xf32, #encoding>
205+
util.return %2 : tensor<16x64xf32, #encoding>
206+
}
207+
// CHECK: #[[$ENCODING:.+]] = #iree_encoding.testing<>
208+
// CHECK-LABEL: @reshape_fusion
209+
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region -> (tensor<16x64xf32, #[[$ENCODING]]>)
210+
// CHECK: linalg.add
211+
// CHECK: tensor.collapse_shape
212+
// CHECK: tensor.expand_shape
213+
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
214+
// CHECK: flow.return %[[SET_ENCODING]] :
215+
// CHECK: }
216+
// CHECK: util.return %[[DISPATCH0]] : tensor<16x64xf32, #[[$ENCODING]]>

0 commit comments

Comments
 (0)