Skip to content

Commit b864909

Browse files
[mlir][SCF] Allow using a custom operation to generate loops with mlir::tileUsingSCF. (llvm#159506)
This change adds an option to use a custom operation to generate the inter-tile loops during tiling. When the loop type is set to `scf::SCFTilingOptions::LoopType::CustomOp`, the method `mlir::tileUsingSCF` provides two callback functions 1. First one to generate the header of the loop. 2. Second one to generate the terminator of the loop. These methods receive the information needed to generate the loops/terminator and are expected to return information needed to generate the code for the intra-tile computation. See comments for more details. Presently this adds support only for tiling. Subsequent commits will update this to add support for fusion as well. The PR is split into two commits. 1) The first commit is an NFC that just refactors the code (and cleans up some naming) to make it easier to add the support for custom loop operations. 2) The second commit adds the support for using a custom loop operation, as well as a test to exercise this path. Signed-off-by: MaheshRavishankar <[email protected]> --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 4fabe6f commit b864909

File tree

5 files changed

+659
-210
lines changed

5 files changed

+659
-210
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =
3333

3434
/// Options to use to control tiling.
3535
struct SCFTilingOptions {
36+
/// Specify which loop construct to use for tile and fuse.
37+
enum class LoopType { ForOp, ForallOp, CustomOp };
38+
LoopType loopType = LoopType::ForOp;
39+
SCFTilingOptions &setLoopType(LoopType type) {
40+
loopType = type;
41+
return *this;
42+
}
43+
3644
/// Computation function that returns the tile sizes to use for each loop.
3745
/// Returning a tile size of zero implies no tiling for that loop. If the
3846
/// size of the returned vector is smaller than the number of loops, the inner
@@ -50,6 +58,17 @@ struct SCFTilingOptions {
5058
/// proper interaction with folding.
5159
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
5260

61+
/// The interchange vector to reorder the tiled loops.
62+
SmallVector<int64_t> interchangeVector = {};
63+
SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
64+
interchangeVector = llvm::to_vector(interchange);
65+
return *this;
66+
}
67+
68+
//-------------------------------------------------------------------------//
69+
// Options related to tiling using `scf.forall`.
70+
//-------------------------------------------------------------------------//
71+
5372
/// Computation function that returns the number of threads to use for
5473
/// each loop. Returning a num threads of zero implies no tiling for that
5574
/// loop. If the size of the returned vector is smaller than the number of
@@ -70,21 +89,6 @@ struct SCFTilingOptions {
7089
/// function that computes num threads at the point they are needed.
7190
SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
7291

73-
/// The interchange vector to reorder the tiled loops.
74-
SmallVector<int64_t> interchangeVector = {};
75-
SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
76-
interchangeVector = llvm::to_vector(interchange);
77-
return *this;
78-
}
79-
80-
/// Specify which loop construct to use for tile and fuse.
81-
enum class LoopType { ForOp, ForallOp };
82-
LoopType loopType = LoopType::ForOp;
83-
SCFTilingOptions &setLoopType(LoopType type) {
84-
loopType = type;
85-
return *this;
86-
}
87-
8892
/// Specify mapping of loops to devices. This is only respected when the loop
8993
/// constructs support such a mapping (like `scf.forall`). Will be ignored
9094
/// when using loop constructs that don't support such a mapping (like
@@ -117,6 +121,98 @@ struct SCFTilingOptions {
117121
reductionDims.insert(dims.begin(), dims.end());
118122
return *this;
119123
}
124+
125+
//-------------------------------------------------------------------------//
126+
// Options related to tiling using custom loop.
127+
//-------------------------------------------------------------------------//
128+
129+
// For generating the inter-tile loops using a custom loop, two callback
130+
// functions are needed
131+
// 1. That generates the "loop header", i.e. the loop that iterates over the
132+
// different tiles.
133+
// 2. That generates the loop terminator
134+
//
135+
// For `scf.forall` case the call back to generate loop header would generate
136+
//
137+
// ```mlir
138+
// scf.forall (...) = ... {
139+
// ..
140+
// }
141+
// ```
142+
//
143+
// and the call back to generate the loop terminator would generate the
144+
// `scf.in_parallel` region
145+
//
146+
// ```mlir
147+
// scf.forall (...) = ... {
148+
// scf.in_parallel {
149+
// tensor.parallel_insert_slice ...
150+
// }
151+
// }
152+
// ```
153+
//
154+
155+
// Information that is to be returned by the callback to generate the loop
156+
// header needed for the rest of the tiled code generation.
157+
// - `loops`: The generated loops
158+
// - `tileOffset`: The values that represent the offset of the iteration space
159+
// tile
160+
// - `tileSizes` : The values that represent the size of the iteration space
161+
// tile.
162+
// - `destinationTensors` : The tensors to use as destinations during tiling.
163+
struct CustomLoopHeaderInfo {
164+
SmallVector<LoopLikeOpInterface> loops;
165+
SmallVector<OpFoldResult> tileOffset;
166+
SmallVector<OpFoldResult> tileSizes;
167+
SmallVector<Value> destinationTensors;
168+
};
169+
170+
// Type of the callback function that generates the loop headers.
171+
// - `loopRanges` : Values that represent the full size of the iteration space
172+
// being tiled.
173+
// - `givenTileSizes` : The tile sizes that are to be used to tile the
174+
// iteration
175+
// space.
176+
// - `destinationTensors` : The tensors to use as destinations for the results
177+
// of the tiled loop for loops that implement
178+
// `DestinationStyleOpInterface`.
179+
// Returns the `CustomLoopHeaderInfo` object (described above). It is expected
180+
// that this function sets the insertion point of `rewriter` to the program
181+
// point where the intra-tile loop computation is to be generated.
182+
using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
183+
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
184+
ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
185+
186+
// Type of the callback function that generates the loop terminator.
187+
// - `tiledResults` : Tiles of the result computed for the iteration space
188+
// tile
189+
// - `resultOffsets` : For each of the `tiledResults`, the offset at which
190+
// the result tile is to be "inserted" back into the
191+
// destination tensor.
192+
// - `resultSizes` : For each of the `tiledResults`, the size of the result
193+
// tile
194+
// that is to be "inserted" back into the destination
195+
// tensor.
196+
// Returns a `LogicalResult` indicating whether the terminator was generated.
197+
using GenerateLoopTerminatorFn = std::function<LogicalResult(
198+
RewriterBase &rewriter, Location loc, ValueRange tiledResults,
199+
ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
200+
ArrayRef<SmallVector<OpFoldResult>> resultSizes,
201+
ValueRange destinationTensors)>;
202+
203+
// Callback function to generate the inter-tile loop header.
204+
GenerateLoopHeaderFn generateLoopHeaderFn = nullptr;
205+
// Callback function to generate the inter-tile loop terminator.
206+
GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr;
207+
// Helper function to set the callbacks for inter-tile loop header and
208+
// terminator functions when using a custom operation for the loop.
209+
SCFTilingOptions &
210+
setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
211+
GenerateLoopTerminatorFn terminatorFn) {
212+
generateLoopHeaderFn = std::move(headerFn);
213+
generateLoopTerminatorFn = std::move(terminatorFn);
214+
return *this;
215+
}
120216
};
121217

122218
/// Transformation information returned after tiling.

0 commit comments

Comments
 (0)