Skip to content

Commit 178651a

Browse files
authored
[MLIR][SCF] Add loops as parameter to LoopTerminator callback when using CustomOp. (#161386)
This PR adds to the generateLoopTerminatorFn callback the loops generated by GenerateLoopHeaderFn. This is needed to correctly set the insertion point with scf.forall ops.
1 parent 71d8ddc commit 178651a

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ struct SCFTilingOptions {
183183
ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
184184

185185
// Type of the callback function that generates the loop terminator.
186+
// - `loops` : generated loops from the GenerateLoopHeaderFn callback
186187
// - `tiledResults` : Tiles of the result computed for the iteration space
187188
// tile.
188189
// - `resultOffsets` : For each of the `tiledResults`, the offset at which
@@ -193,7 +194,8 @@ struct SCFTilingOptions {
193194
// tensor.
194195
// Returns the `CustomLoopHeaderInfo` object (described above)
195196
using GenerateLoopTerminatorFn = std::function<LogicalResult(
196-
RewriterBase &rewriter, Location loc, ValueRange tiledResults,
197+
RewriterBase &rewriter, Location loc, ArrayRef<LoopLikeOpInterface> loops,
198+
ValueRange tiledResults,
197199
ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
198200
ArrayRef<SmallVector<OpFoldResult>> resultSizes,
199201
ValueRange destinationTensors)>;

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,8 @@ generateLoopNestUsingCustomOp(
665665
return failure();
666666
}
667667

668-
if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults,
669-
resultOffsets, resultSizes,
668+
if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
669+
tiledResults, resultOffsets, resultSizes,
670670
loopHeaderInfo->destinationTensors))) {
671671
return failure();
672672
}

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,8 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
581581
};
582582

583583
scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
584-
[&](RewriterBase &rewriter, Location loc, ValueRange tiledResults,
584+
[&](RewriterBase &rewriter, Location loc,
585+
ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults,
585586
ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
586587
ArrayRef<SmallVector<OpFoldResult>> resultSizes,
587588
ValueRange destinationTensors) -> LogicalResult {

0 commit comments

Comments
 (0)