2020#include " llvm/ADT/TypeSwitch.h"
2121#include " llvm/Support/Casting.h"
2222#include " llvm/Support/Debug.h"
23+ #include " mlir/Analysis/SliceAnalysis.h"
2324#include " mlir/Analysis/TopologicalSortUtils.h"
2425#include " mlir/Dialect/Affine/IR/AffineOps.h"
2526#include " mlir/Dialect/Arith/IR/Arith.h"
4344#include " mlir/Pass/Pass.h"
4445#include " mlir/Support/LLVM.h"
4546#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
47+ #include " mlir/Transforms/RegionUtils.h"
4648
4749#define DEBUG_TYPE " iree-dispatch-creation-form-dispatch-regions"
4850
@@ -145,6 +147,50 @@ class FusionGroup {
145147 // Insert `op` into the fusion group.
146148 void insert (Operation *op);
147149
150+ // / Returns true if `consumerOp` has a transitive dependency on the fusion
151+ // / group. This means that some transitive dependency of `consumerOp` (not in
152+ // / the fusion group) itself uses an operation in the fusion group. This is
153+ // / required for fusion because it must be legal to take a program slice that
154+ // / contains only the ops in the fusion group.
155+ bool
156+ hasTransitiveDependencyOnFusionGroup (Operation *consumerOp,
157+ DominanceInfo const &dominance) const {
158+ BackwardSliceOptions options;
159+ options.inclusive = true ;
160+ options.omitUsesFromAbove = false ;
161+ options.omitBlockArguments = true ;
162+ options.filter = [&](Operation *sliceBoundaryOp) {
163+ return !llvm::all_of (
164+ loopMaps.getArrayRef (), [&](std::pair<Operation *, AffineMap> pair) {
165+ return dominance.properlyDominates (sliceBoundaryOp, pair.first );
166+ });
167+ };
168+
169+ llvm::SetVector<Operation *> slice;
170+ auto populateSlice = [&](OpOperand *operand) {
171+ // It's okay if the consumer directly uses an operation in the fusion
172+ // group.
173+ if (loopMaps.contains (operand->get ().getDefiningOp ())) {
174+ return ;
175+ }
176+ LogicalResult result = getBackwardSlice (operand->get (), &slice, options);
177+ assert (result.succeeded () && " expected a backward slice" );
178+ (void )result;
179+ };
180+
181+ // Search all of the operands op `consumerOp` as well as all the values used
182+ // in its regions.
183+ mlir::visitUsedValuesDefinedAbove (consumerOp->getRegions (), populateSlice);
184+ for (OpOperand &operand : consumerOp->getOpOperands ()) {
185+ populateSlice (&operand);
186+ }
187+
188+ return llvm::any_of (loopMaps.getArrayRef (),
189+ [&](std::pair<Operation *, AffineMap> pair) {
190+ return slice.contains (pair.first );
191+ });
192+ }
193+
148194private:
149195 Operation *rootOp;
150196 // All operations to be fused with the root op. This does not include
@@ -435,6 +481,9 @@ getFusableUses(MLIRContext *context, Operation *op,
435481 if (isa<tensor::DimOp>(user)) {
436482 continue ;
437483 }
484+ if (op->getBlock () != user->getBlock ()) {
485+ continue ;
486+ }
438487 fusableUses.insert (&use);
439488 }
440489
@@ -667,6 +716,13 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
667716 continue ;
668717 }
669718
719+ // Ensure that fusing the consumer would not cause use-def violations.
720+ if (tracker.getFusionGroup (currRoot)
721+ .hasTransitiveDependencyOnFusionGroup (fusableUse->getOwner (),
722+ dominanceInfo)) {
723+ continue ;
724+ }
725+
670726 if (isFusableWithConsumer (*fusableUse, tracker, options)) {
671727 tracker.appendToFusionGroup (consumerOp, fusionGroup);
672728 workList.push_back (consumerOp);
@@ -957,7 +1013,8 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
9571013 auto newRegionOp =
9581014 movePrecedingOpsIntoDispatchRegion (rewriter, producer, regionOp);
9591015 if (failed (newRegionOp)) {
960- return producer->emitOpError (" failed to move producer into region" );
1016+ producer->emitWarning (" failed to move producer into region" );
1017+ continue ;
9611018 }
9621019 regionOp = *newRegionOp;
9631020 }
@@ -974,7 +1031,7 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
9741031 auto newRegionOp = IREE::Flow::moveFollowingOpIntoDispatchRegion (
9751032 rewriter, consumer, regionOp);
9761033 if (failed (newRegionOp)) {
977- continue ;
1034+ return consumer-> emitOpError ( " failed to move consumer into region " ) ;
9781035 }
9791036 regionOp = *newRegionOp;
9801037 }
0 commit comments