@@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =
3333
3434// / Options to use to control tiling.
3535struct SCFTilingOptions {
36+ // / Specify which loop construct to use for tile and fuse.
37+ enum class LoopType { ForOp, ForallOp, CustomOp };
38+ LoopType loopType = LoopType::ForOp;
39+ SCFTilingOptions &setLoopType (LoopType type) {
40+ loopType = type;
41+ return *this ;
42+ }
43+
3644 // / Computation function that returns the tile sizes to use for each loop.
3745 // / Returning a tile size of zero implies no tiling for that loop. If the
3846 // / size of the returned vector is smaller than the number of loops, the inner
@@ -50,6 +58,17 @@ struct SCFTilingOptions {
5058 // / proper interaction with folding.
5159 SCFTilingOptions &setTileSizes (ArrayRef<OpFoldResult> tileSizes);
5260
61+ // / The interchange vector to reorder the tiled loops.
62+ SmallVector<int64_t > interchangeVector = {};
63+ SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
64+ interchangeVector = llvm::to_vector (interchange);
65+ return *this ;
66+ }
67+
68+ // -------------------------------------------------------------------------//
69+ // Options related to tiling using `scf.forall`.
70+ // -------------------------------------------------------------------------//
71+
5372 // / Computation function that returns the number of threads to use for
5473 // / each loop. Returning a num threads of zero implies no tiling for that
5574 // / loop. If the size of the returned vector is smaller than the number of
@@ -70,21 +89,6 @@ struct SCFTilingOptions {
7089 // / function that computes num threads at the point they are needed.
7190 SCFTilingOptions &setNumThreads (ArrayRef<OpFoldResult> numThreads);
7291
73- // / The interchange vector to reorder the tiled loops.
74- SmallVector<int64_t > interchangeVector = {};
75- SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
76- interchangeVector = llvm::to_vector (interchange);
77- return *this ;
78- }
79-
80- // / Specify which loop construct to use for tile and fuse.
81- enum class LoopType { ForOp, ForallOp };
82- LoopType loopType = LoopType::ForOp;
83- SCFTilingOptions &setLoopType (LoopType type) {
84- loopType = type;
85- return *this ;
86- }
87-
8892 // / Specify mapping of loops to devices. This is only respected when the loop
8993 // / constructs support such a mapping (like `scf.forall`). Will be ignored
9094 // / when using loop constructs that don't support such a mapping (like
@@ -117,6 +121,96 @@ struct SCFTilingOptions {
117121 reductionDims.insert (dims.begin (), dims.end ());
118122 return *this ;
119123 }
124+
125+ // -------------------------------------------------------------------------//
126+ // Options related to tiling using custom loop.
127+ // -------------------------------------------------------------------------//
128+
129+ // For generating the inter-tile loops using a custom loop, two callback
130+ // functions are needed:
131+ // 1. One that generates the "loop header", i.e. the loop that iterates over
132+ // the different tiles.
133+ // 2. One that generates the loop terminator.
134+ //
135+ // For `scf.forall`, the callback to generate the loop header would generate
136+ //
137+ // ```mlir
138+ // scf.forall (...) = ... {
139+ // ..
140+ // }
141+ // ```
142+ //
143+ // and the callback to generate the loop terminator would generate the
144+ // `scf.in_parallel` region
145+ //
146+ // ```mlir
147+ // scf.forall (...) = ... {
148+ // scf.in_parallel {
149+ // tensor.parallel_insert_slice ...
150+ // }
151+ // }
152+ // ```
153+ //
154+
155+ // Information returned by the loop header callback that is needed for the
156+ // rest of the tiled code generation.
157+ // - `loops`: The generated loops
158+ // - `tileOffset`: The values that represent the offset of the iteration space
159+ // tile.
160+ // - `tileSizes` : The values that represent the size of the iteration space
161+ // tile.
162+ // - `destinationTensors` : The tensors to use as destinations during tiling.
163+ struct CustomLoopHeaderInfo {
164+ SmallVector<LoopLikeOpInterface> loops;
165+ SmallVector<OpFoldResult> tileOffset;
166+ SmallVector<OpFoldResult> tileSizes;
167+ SmallVector<Value> destinationTensors;
168+ };
169+
170+ // Type of the callback function that generates the loop headers.
171+ // - `loopRanges` : Values that represent the full size of the iteration space
172+ // being tiled.
173+ // - `givenTileSizes` : The tile sizes that are to be used to tile the
174+ // iteration space.
175+ // - `destinationTensors` : The tensors to use as destinations for the results
176+ // of the tiled loop for loops that implement
177+ // `DestinationStyleOpInterface`.
178+ // Returns the `CustomLoopHeaderInfo` object (described above). It is expected
179+ // that this function sets the insertion point of `rewriter` to the program
180+ // point where the intra-tile loop computation is to be generated.
181+ using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
182+ RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
183+ ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
184+
185+ // Type of the callback function that generates the loop terminator.
186+ // - `tiledResults` : Tiles of the result computed for the iteration space
187+ // tile.
188+ // - `resultOffsets` : For each of the `tiledResults`, the offset at which
189+ // the result tile is to be "inserted" back into the
190+ // destination tensor.
191+ // - `resultSizes` : For each of the `tiledResults`, the size of the result
192+ // tile that is to be "inserted" back into the destination
193+ // tensor.
194+ // Returns a `LogicalResult` indicating whether the terminator was generated.
195+ using GenerateLoopTerminatorFn = std::function<LogicalResult(
196+ RewriterBase &rewriter, Location loc, ValueRange tiledResults,
197+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
198+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
199+ ValueRange destinationTensors)>;
200+
201+ // Callback function to generate the inter-tile loop header.
202+ GenerateLoopHeaderFn generateLoopHeaderFn = nullptr ;
203+ // Callback function to generate the inter-tile loop terminator.
204+ GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr ;
205+ // Helper function to set the callbacks for inter-tile loop header and
206+ // terminator functions when using a custom operation for the loop.
207+ SCFTilingOptions &
208+ setCustomLoopGenerationFns (GenerateLoopHeaderFn headerFn,
209+ GenerateLoopTerminatorFn terminatorFn) {
210+ generateLoopHeaderFn = std::move (headerFn);
211+ generateLoopTerminatorFn = std::move (terminatorFn);
212+ return *this ;
213+ }
120214};
121215
122216// / Transformation information returned after tiling.
0 commit comments