@@ -33,14 +33,6 @@ using SCFTileSizeComputationFunction =
3333
3434// / Options to use to control tiling.
3535struct SCFTilingOptions {
36- // / Specify which loop construct to use for tile and fuse.
37- enum class LoopType { ForOp, ForallOp, CustomOp };
38- LoopType loopType = LoopType::ForOp;
39- SCFTilingOptions &setLoopType (LoopType type) {
40- loopType = type;
41- return *this ;
42- }
43-
4436 // / Computation function that returns the tile sizes to use for each loop.
4537 // / Returning a tile size of zero implies no tiling for that loop. If the
4638 // / size of the returned vector is smaller than the number of loops, the inner
@@ -58,17 +50,6 @@ struct SCFTilingOptions {
5850 // / proper interaction with folding.
5951 SCFTilingOptions &setTileSizes (ArrayRef<OpFoldResult> tileSizes);
6052
61- // / The interchange vector to reorder the tiled loops.
62- SmallVector<int64_t > interchangeVector = {};
63- SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
64- interchangeVector = llvm::to_vector (interchange);
65- return *this ;
66- }
67-
68- // -------------------------------------------------------------------------//
69- // Options related to tiling using `scf.forall`.
70- // -------------------------------------------------------------------------//
71-
7253 // / Computation function that returns the number of threads to use for
7354 // / each loop. Returning a num threads of zero implies no tiling for that
7455 // / loop. If the size of the returned vector is smaller than the number of
@@ -89,6 +70,21 @@ struct SCFTilingOptions {
8970 // / function that computes num threads at the point they are needed.
9071 SCFTilingOptions &setNumThreads (ArrayRef<OpFoldResult> numThreads);
9172
73+ // / The interchange vector to reorder the tiled loops.
74+ SmallVector<int64_t > interchangeVector = {};
75+ SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
76+ interchangeVector = llvm::to_vector (interchange);
77+ return *this ;
78+ }
79+
80+ // / Specify which loop construct to use for tile and fuse.
81+ enum class LoopType { ForOp, ForallOp };
82+ LoopType loopType = LoopType::ForOp;
83+ SCFTilingOptions &setLoopType (LoopType type) {
84+ loopType = type;
85+ return *this ;
86+ }
87+
9288 // / Specify mapping of loops to devices. This is only respected when the loop
9389 // / constructs support such a mapping (like `scf.forall`). Will be ignored
9490 // / when using loop constructs that dont support such a mapping (like
@@ -121,98 +117,6 @@ struct SCFTilingOptions {
121117 reductionDims.insert (dims.begin (), dims.end ());
122118 return *this ;
123119 }
124-
125- // -------------------------------------------------------------------------//
126- // Options related to tiling using custom loop.
127- // -------------------------------------------------------------------------//
128-
129- // For generating the inter-tile loops using a custom loop, two callback
130- // functions are needed
131- // 1. That generates the "loop header", i.e. the loop that iterates over the
132- // different tiles.
133- // 2. That generates the loop terminator
134- //
135- // For `scf.forall` case the call back to generate loop header would generate
136- //
137- // ```mlir
138- // scf.forall (...) = ... {
139- // ..
140- // }
141- // ```
142- //
143- // and the call back to generate the loop terminator would generate the
144- // `scf.in_parallel` region
145- //
146- // ```mlir
147- // scf.forall (...) = ... {
148- // scf.in_parallel {
149- // tensor.parallel_insert_slice ...
150- // }
151- // }
152- // ```
153- //
154-
155- // Information that is to be returned by the callback to generate the loop
156- // header needed for the rest of the tiled codegeneration.
157- // - `loops`: The generated loops
158- // - `tileOffset`: The values that represent the offset of the iteration space
159- // tile
160- // - `tileSizes` : The values that represent the size of the iteration space
161- // tile.
162- // - `destinationTensors` : The tensors to use as destinations during tiling.
163- struct CustomLoopHeaderInfo {
164- SmallVector<LoopLikeOpInterface> loops;
165- SmallVector<OpFoldResult> tileOffset;
166- SmallVector<OpFoldResult> tileSizes;
167- SmallVector<Value> destinationTensors;
168- };
169-
170- // Type of the callback function that generates the loop headers.
171- // - `loopRanges` : Values that represent the full size of the iteration space
172- // being tiled.
173- // - `giveTileSizes` : The tile sizes that are to be used to tile the
174- // iteration
175- // space.
176- // - `destinationTensors` : The tensors to use as destinations for the results
177- // of the tiled loop for loops that implement
178- // `DestinationStyleOpInterface`.
179- // Returns the `CustomLoopHeaderInfo` object (described above). it is expected
180- // that this function sets the insertion point of `rewriter` to the program
181- // point where the intra-tile loop computation is to be generated.
182- using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
183- RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
184- ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
185-
186- // Type of the callback function that generates the loop terminator.
187- // - `tiledResults` : Tiles of the result computed for the iteration space
188- // tile
189- // - `resultOffsets` : For each of the `tiledResults`, the offset at which
190- // the result tile is to be "inserted" back into the
191- // destination tensor.
192- // - `resultSizes` : For each of the `tiledResults`, the size of the result
193- // tile
194- // that is to be "inserted" back into the destination
195- // tensor.
196- // Returns the `CustomLoopHeaderInfo` object (described above)
197- using GenerateLoopTerminatorFn = std::function<LogicalResult(
198- RewriterBase &rewriter, Location loc, ValueRange tiledResults,
199- ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
200- ArrayRef<SmallVector<OpFoldResult>> resultSizes,
201- ValueRange destinationTensors)>;
202-
203- // Callback function to generate the inter-tile loop header.
204- GenerateLoopHeaderFn generateLoopHeaderFn = nullptr ;
205- // Callback function to generate the inter-tile loop terminator.
206- GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr ;
207- // Helper function to set the callbacks for inter-tile loop header and
208- // terminator functions when using a custom operation for the loop.
209- SCFTilingOptions &
210- setCustomLoopGenerationFns (GenerateLoopHeaderFn headerFn,
211- GenerateLoopTerminatorFn terminatorFn) {
212- generateLoopHeaderFn = std::move (headerFn);
213- generateLoopTerminatorFn = std::move (terminatorFn);
214- return *this ;
215- }
216120};
217121
218122// / Transformation information returned after tiling.
0 commit comments