Skip to content

Commit ff13ad0

Browse files
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 1bb1eb6 commit ff13ad0

File tree

5 files changed

+63
-53
lines changed

5 files changed

+63
-53
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ struct FuseTilableForallConsumers final
282282
}
283283

284284
tensor::ParallelInsertSliceOp producerSlice;
285-
scf::ForallOp sliceOwner;
285+
LoopLikeOpInterface sliceOwner;
286286
Value fusionOperand;
287287
for (auto operand : dpsOp.getDpsInputs()) {
288288
auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
@@ -320,7 +320,7 @@ struct FuseTilableForallConsumers final
320320
}
321321

322322
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
323-
scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice);
323+
scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice, {sliceOwner});
324324
if (failed(fuseConsumerResults)) {
325325
return failure();
326326
}

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ static bool areAllStaticLoopBounds(scf::ForallOp forallOp) {
237237

238238
/// Find dimensions of the loop that are unit-trip count and drop them from the
239239
/// distributed dimensions.
240-
static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
241-
scf::ForallOp forallOp) {
240+
static FailureOr<scf::ForallOp>
241+
dropUnitDistributedDims(RewriterBase &rewriter, scf::ForallOp forallOp) {
242242
SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
243243
SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
244244
SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep();
@@ -261,7 +261,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
261261
}
262262
}
263263
if (droppedLoops.empty()) {
264-
return success();
264+
return forallOp;
265265
}
266266

267267
OpBuilder::InsertionGuard g(rewriter);
@@ -303,7 +303,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
303303
rewriter.mergeBlocks(oldLoopBody, newLoopBody, argReplacements);
304304

305305
rewriter.replaceOp(forallOp, newForallOp.getResults());
306-
return success();
306+
return newForallOp;
307307
}
308308

309309
//===---------------------------------------------------------------------===//
@@ -314,8 +314,9 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
314314
// Returns a list of new `tensor.extract_slice` ops with new fusion
315315
// opportunities. The surrounding loops are passed in through `loops`
316316
// (consumer fusion replaces the loop, so the entries may be updated).
317-
static std::pair<std::queue<Operation *>, scf::ForallOp>
318-
fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
317+
static std::queue<Operation *>
318+
fuseConsumers(RewriterBase &rewriter, Operation *tiledOp,
319+
MutableArrayRef<LoopLikeOpInterface> loops) {
319320
auto addCandidateSlices =
320321
[](Operation *fusedOp,
321322
std::queue<tensor::ParallelInsertSliceOp> &candidates) {
@@ -333,15 +334,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
333334
addCandidateSlices(tiledOp, candidates);
334335

335336
std::queue<Operation *> newFusionOpportunities;
336-
scf::ForallOp newLoop = tiledOp->getParentOfType<scf::ForallOp>();
337337
while (!candidates.empty()) {
338338

339339
// Traverse the slices in BFS fashion.
340340
tensor::ParallelInsertSliceOp candidateSliceOp = candidates.front();
341341
candidates.pop();
342342

343343
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
344-
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
344+
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp,
345+
loops);
345346
if (failed(fusedResult)) {
346347
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
347348
<< candidateSliceOp << "\n");
@@ -369,19 +370,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
369370
}
370371
}
371372
}
372-
// Store the new loop for follow up producer fusion.
373-
newLoop = tiledOp->getParentOfType<scf::ForallOp>();
374373
}
375374
}
376-
return std::make_pair(newFusionOpportunities, newLoop);
375+
return newFusionOpportunities;
377376
}
378377

379378
static void fuseProducersOfSlices(RewriterBase &rewriter,
380379
std::queue<Operation *> &worklist,
381380
scf::SCFTileAndFuseOptions &options,
382-
scf::ForallOp forallOp) {
383-
SmallVector<LoopLikeOpInterface> loops = {
384-
cast<LoopLikeOpInterface>(&*forallOp)};
381+
MutableArrayRef<LoopLikeOpInterface> loops) {
385382
while (!worklist.empty()) {
386383
auto candidateSlice = cast<tensor::ExtractSliceOp>(worklist.front());
387384
worklist.pop();
@@ -532,7 +529,6 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
532529

533530
// If the `tilableOp` is a `memref` op, then just tile the operation.
534531
SmallVector<LoopLikeOpInterface> tilingLoops;
535-
Operation *rootTiledOp = nullptr;
536532
if (tilableOp->getNumResults() == 0) {
537533
FailureOr<scf::SCFTilingResult> tilingResult =
538534
scf::tileUsingSCF(rewriter, tilableOp, tilingOptions);
@@ -554,7 +550,16 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
554550
rewriter.replaceAllUsesWith(origValue, replacement);
555551
}
556552
std::swap(tileAndFuseResult->loops, tilingLoops);
557-
rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
553+
Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
554+
auto newFusionOpportunities =
555+
fuseConsumers(rewriter, rootTiledOp, tilingLoops);
556+
557+
// Because we restrict to at most a single tilable consumer for yielding
558+
// a replacement, no new fusion opportunities will yield a replacement,
559+
// meaning there is no need to run consumer fusion again afterwards.
560+
// TODO: run producer and consumer fusion in one worklist.
561+
fuseProducersOfSlices(rewriter, newFusionOpportunities, tileAndFuseOptions,
562+
tilingLoops);
558563
}
559564
if (!tilingLoops.empty()) {
560565
if (tilingLoops.size() != 1 || !isa<scf::ForallOp>(tilingLoops[0])) {
@@ -563,35 +568,24 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
563568
return signalPassFailure();
564569
}
565570

566-
auto forallOp = cast<scf::ForallOp>(tilingLoops[0]);
567-
if (failed(dropUnitDistributedDims(rewriter, forallOp))) {
568-
forallOp.emitOpError("failed to drop unit dimensions");
571+
auto forallOp =
572+
dropUnitDistributedDims(rewriter, cast<scf::ForallOp>(tilingLoops[0]));
573+
if (failed(forallOp)) {
574+
tilingLoops[0]->emitOpError("failed to drop unit dimensions");
569575
return signalPassFailure();
570576
}
571577

572-
if (rootTiledOp) {
573-
auto [newFusionOpportunities, newLoop] =
574-
fuseConsumers(rewriter, rootTiledOp);
575-
576-
// Because we restrict to at most a single tilable consumer for yielding
577-
// a replacement, no new fusion opportunities will yield a replacement,
578-
// meaning there is no need to run consumer fusion again afterwards.
579-
// TODO: run producer and consumer fusion in one worklist.
580-
fuseProducersOfSlices(rewriter, newFusionOpportunities,
581-
tileAndFuseOptions, newLoop);
582-
forallOp = newLoop;
583-
}
584-
585578
// Reorder the workgroups if the strategy is set to `transpose`.
586579
// This just transposes the first two dimensions of the workgroup i.e., the
587580
// #iree.codegen.workgroup_id_x and #iree.codegen.workgroup_id_y.
588581
// Only reorders if the loop bounds are static.
589582
if (transposeWorkgroup) {
590-
SmallVector<Attribute> mappingAttrs(forallOp.getMappingAttr().getValue());
583+
SmallVector<Attribute> mappingAttrs(
584+
forallOp->getMappingAttr().getValue());
591585
int64_t mappingSize = mappingAttrs.size();
592-
if (areAllStaticLoopBounds(forallOp) && mappingSize >= 2) {
586+
if (areAllStaticLoopBounds(*forallOp) && mappingSize >= 2) {
593587
std::swap(mappingAttrs[mappingSize - 1], mappingAttrs[mappingSize - 2]);
594-
forallOp.setMappingAttr(ArrayAttr::get(context, mappingAttrs));
588+
forallOp->setMappingAttr(ArrayAttr::get(context, mappingAttrs));
595589
}
596590
}
597591
}

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,7 @@ template <typename Range>
11931193
static LogicalResult
11941194
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
11951195
Range &&payloadOps,
1196+
MutableArrayRef<LoopLikeOpInterface> loops,
11961197
transform::TransformResults &transformResults) {
11971198
SmallVector<Operation *> originalConsumerOps;
11981199
SmallVector<Operation *> fusedConsumerOps;
@@ -1201,7 +1202,7 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
12011202
rewriter.setInsertionPoint(target);
12021203

12031204
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
1204-
scf::tileAndFuseConsumerOfSlice(rewriter, target);
1205+
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
12051206

12061207
if (failed(fuseConsumerResults))
12071208
return failure();
@@ -1222,9 +1223,18 @@ DiagnosedSilenceableFailure transform_dialect::FuseConsumerOp::apply(
12221223
transform::TransformRewriter &rewriter,
12231224
transform::TransformResults &transformResults,
12241225
transform::TransformState &state) {
1225-
LogicalResult result =
1226-
applyFuseConsumer(rewriter, getOperation(),
1227-
state.getPayloadOps(getTarget()), transformResults);
1226+
SmallVector<LoopLikeOpInterface> loops;
1227+
for (auto op : getLoops()) {
1228+
auto loopOp =
1229+
dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(op).begin());
1230+
if (!loopOp) {
1231+
return DiagnosedSilenceableFailure::definiteFailure();
1232+
}
1233+
loops.push_back(loopOp);
1234+
}
1235+
LogicalResult result = applyFuseConsumer(rewriter, getOperation(),
1236+
state.getPayloadOps(getTarget()),
1237+
loops, transformResults);
12281238
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
12291239
: DiagnosedSilenceableFailure::success();
12301240
}

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,13 +694,14 @@ def FuseConsumerOp : Op<Transform_Dialect, "iree.fuse_consumer",
694694
}];
695695
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
696696

697-
let arguments =
698-
(ins TransformHandleTypeInterface:$target);
697+
let arguments =(ins
698+
TransformHandleTypeInterface:$target,
699+
Variadic<TransformHandleTypeInterface>:$loops);
699700
let results = (outs TransformHandleTypeInterface:$consumer,
700701
TransformHandleTypeInterface:$fused_consumer);
701702

702703
let assemblyFormat = [{
703-
$target attr-dict `:` functional-type(operands, results)
704+
$target `in` `(` $loops `)` attr-dict `:` functional-type(operands, results)
704705
}];
705706
}
706707

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ static void collectTiledAndFusedOps(Operation *rootOp,
5656
/// Tile the root operation and fuse the producers of the root operation.
5757
/// If `onlyFuseProducerInputOperands` is set, only fuse producer input
5858
/// operands. Returns the tile-and-fuse result used for fusing consumers.
59-
FailureOr<Operation *>
59+
static FailureOr<scf::SCFTileAndFuseResult>
6060
tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
6161
int64_t tilingLevel,
6262
bool onlyFuseProducerInputOperands) {
@@ -136,10 +136,11 @@ tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
136136
}
137137
}
138138

139-
return tiledResults->tiledAndFusedOps.front();
139+
return tiledResults;
140140
}
141141

142-
static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
142+
static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp,
143+
MutableArrayRef<LoopLikeOpInterface> loops) {
143144

144145
// Typically, the consumers of the tiled operation are slices of the
145146
// results of the tiled operation. These are expressed in IR using
@@ -169,7 +170,8 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
169170
candidates.pop();
170171

171172
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
172-
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
173+
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp,
174+
loops);
173175
if (failed(fusedResult)) {
174176
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
175177
<< candidateSliceOp << "\n");
@@ -196,14 +198,17 @@ static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
196198
int64_t tilingLevel,
197199
bool onlyFuseProducerInputOperands) {
198200

199-
FailureOr<Operation *> tiledOp = tileRootAndFuseProducers(
200-
rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands);
201+
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
202+
tileRootAndFuseProducers(rewriter, rootOp, tilingLevel,
203+
onlyFuseProducerInputOperands);
201204

202-
if (failed(tiledOp))
205+
if (failed(tileAndFuseResult))
203206
return failure();
204207

205-
if (!onlyFuseProducerInputOperands)
206-
fuseConsumers(rewriter, tiledOp.value());
208+
if (!onlyFuseProducerInputOperands) {
209+
fuseConsumers(rewriter, tileAndFuseResult->tiledAndFusedOps.front(),
210+
tileAndFuseResult->loops);
211+
}
207212

208213
return success();
209214
}

0 commit comments

Comments
 (0)