Commit 78ec7f2

[DispatchCreation] Changes to dispatch region in preparation for horizontal fusion changes. (#19876)
Current dispatch region formation handles consumer fusion by making the consumer the root of the DAG that is moved into a dispatch. This approach does not work when there is more than one consumer and the consumers have no direct dependency on each other. This change keeps the root operation as the root and instead moves consumers into the dispatch iteratively. Doing so required a few additional changes:

1) Move the method `moveOperandDefs` into a utility function.
2) Change how the dynamic dims of the results of the created `flow.dispatch.region` ops are resolved.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Ian Wood <[email protected]>
Co-authored-by: Ian Wood <[email protected]>
1 parent eb58f82 · commit 78ec7f2

9 files changed: 425 additions, 95 deletions
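
To illustrate the new formation scheme described in the commit message, here is a rough sketch of the iterative consumer move-in. It is a hypothetical driver, not the code in this patch: the loop structure, the `fusableConsumers` collection, and the exact signatures of `moveOperandDefs` and `moveFollowingOpIntoDispatchRegion` are assumptions for illustration.

// Hypothetical sketch of iterative consumer fusion. Signatures are assumed;
// see RegionOpUtils.cpp for the actual helpers.
static LogicalResult
fuseConsumersIteratively(RewriterBase &rewriter,
                         IREE::Flow::DispatchRegionOp &regionOp,
                         ArrayRef<Operation *> fusableConsumers) {
  for (Operation *consumer : fusableConsumers) {
    // Hoist definitions of the consumer's other operands above the region op
    // so the consumer itself becomes movable (the `moveOperandDefs` utility
    // factored out by this change).
    if (failed(moveOperandDefs(rewriter, consumer, regionOp)))
      return failure();
    // Clone the consumer into the region, yield its results, and replace its
    // uses outside the region.
    FailureOr<IREE::Flow::DispatchRegionOp> newRegionOp =
        moveFollowingOpIntoDispatchRegion(rewriter, consumer, regionOp);
    if (failed(newRegionOp))
      return failure();
    regionOp = *newRegionOp;
  }
  return success();
}

Because the root stays the root, two consumers with no direct dependency on each other can both be moved into the same region, which the old "make the consumer the new root" scheme could not express.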

compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp

Lines changed: 9 additions & 3 deletions
@@ -45,7 +45,8 @@ void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) {
 void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op,
                                                         InsertPoint previous) {
   IRRewriter::Listener::notifyOperationInserted(op, previous);
-  if (isa<tensor::DimOp>(op))
+  auto dimOp = dyn_cast<tensor::DimOp>(op);
+  if (dimOp && isa<OpResult>(dimOp.getSource()))
     dimOps.insert(op);
 }
 
@@ -60,16 +61,21 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter,
     std::optional<int64_t> idx = dimOp.getConstantIndex();
     if (!idx.has_value())
       continue;
+
+    if (isa<BlockArgument>(dimOp.getSource())) {
+      continue;
+    }
+
     // Only DimOps with ranked tensors are supported.
     auto tensorType =
         llvm::dyn_cast<RankedTensorType>(dimOp.getSource().getType());
     if (!tensorType)
       continue;
 
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(dimOp);
     if (!tensorType.isDynamicDim(*idx)) {
       // Rewrite static dimension with constant.
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(dimOp);
       int64_t size = tensorType.getShape()[*idx];
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
       continue;
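
For context on why the listener now filters on `OpResult` sources: the tracked ops are later fed to `simplifyDimOps`, which only rewrites `tensor.dim` ops whose source is a ranked-tensor op result, so dims of block arguments no longer need to be collected. A minimal usage sketch; the `getTensorDimOps` accessor name is an assumption based on FormDispatchRegions.cpp:

// Minimal usage sketch (accessor name assumed): track tensor.dim ops created
// during rewriting, then fold the statically known ones to constants.
TensorDimTrackingRewriter rewriter(funcOp);
// ... rewrites that may create tensor.dim ops ...
if (failed(simplifyDimOps(rewriter, rewriter.getTensorDimOps())))
  return failure();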

compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp

Lines changed: 20 additions & 14 deletions
@@ -266,18 +266,8 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
   // Value is an OpResult.
   Operation *op = value.getDefiningOp();
   OpResult opResult = llvm::cast<OpResult>(value);
-  b.setInsertionPoint(op);
 
-  // Case 3: Value is tied. Reify the dimensions of the tied operand.
-  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
-  if (tiedOp) {
-    Value tiedOperand = tiedOp.getTiedResultOperand(value);
-    if (tiedOperand && tiedOperand.getType() == value.getType())
-      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
-                                        createTensorDimOps);
-  }
-
-  // Case 4: Query ShapeAwareOpInterface.
+  // Case 3: Query ShapeAwareOpInterface.
   auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
   if (shapeAwareOp) {
     ValueRange dims =
@@ -286,6 +276,15 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
     return success();
   }
 
+  // Case 4: Value is tied. Reify the dimensions of the tied operand.
+  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
+  if (tiedOp) {
+    Value tiedOperand = tiedOp.getTiedResultOperand(value);
+    if (tiedOperand && tiedOperand.getType() == value.getType())
+      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
+                                        /*createTensorDimOps=*/true);
+  }
+
   // Case 5: Query ReifyRankedShapedTypeOpInterface.
   auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
   if (reifyShapeOp) {
@@ -308,8 +307,14 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
 }
 
 /// Reify the dynamic dimensions of the given value.
+/// Deprecated. Use `getOptimizedDynamicResultDims` instead.
 LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
                                      SmallVectorImpl<Value> &dynamicDims) {
+
+  OpBuilder::InsertionGuard g(b);
+  if (auto op = value.getDefiningOp()) {
+    b.setInsertionPoint(op);
+  }
   return reifyDynamicResultDimsImpl(b, value, dynamicDims,
                                     /*createTensorDimOps=*/true);
 }
@@ -473,7 +478,7 @@ movePrecedingOpsIntoDispatchRegion(RewriterBase &rewriter,
     rewriter.setInsertionPoint(target);
     SmallVector<Value> &dims =
         dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
      return target->emitOpError(
          "failed to reify dynamic dims of result to be yielded from "
          "dispatch region");
@@ -554,9 +559,10 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
   for (auto [index, result] : llvm::enumerate(target->getResults())) {
     replacedValues.push_back(result);
     yieldedResults.push_back(clonedTarget->getResult(index));
-    rewriter.setInsertionPoint(target);
+    OpBuilder::InsertionGuard g1(rewriter);
+    rewriter.setInsertionPoint(regionOp);
     SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
       return target->emitOpError(
           "failed to reify dynamic dims of result to be yielded from "
           "dispatch region");

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 6 additions & 1 deletion
@@ -692,7 +692,12 @@ isContractionOpSequence(Value yielded) {
 /// Recognize an operation that is horizontally fused contraction.
 /// TODO: The logic below is quite convoluted. Might be better
 /// off having a dedicated operation for this.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
+bool isaHorizontallyFusedContraction(Operation *op) {
+  auto linalgOp = dyn_cast_or_null<linalg::GenericOp>(op);
+  if (!linalgOp) {
+    return false;
+  }
+
   if (linalgOp->getNumResults() == 1) {
     return false;
   }
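
With the parameter relaxed from `linalg::LinalgOp` to `Operation *`, callers no longer need to cast before querying; the helper itself now rejects null and non-`linalg.generic` operations. A small usage sketch (the caller function is hypothetical):

// Hypothetical caller: any Operation *, including nullptr, is now safe to pass.
static bool isFusedContractionRoot(Operation *op) {
  return IREE::LinalgExt::isaHorizontallyFusedContraction(op);
}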

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ bool isGatherlikeOp(Operation *op);
 /// Check if a given operation is a horizontally fused contraction operation.
 /// The expectation is that the LHS is common, and all the operands are
 /// different RHS.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
+bool isaHorizontallyFusedContraction(Operation *op);
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
