@@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =
33
33
34
34
// / Options to use to control tiling.
35
35
struct SCFTilingOptions {
36
+ // / Specify which loop construct to use for tile and fuse.
37
+ enum class LoopType { ForOp, ForallOp, CustomOp };
38
+ LoopType loopType = LoopType::ForOp;
39
+ SCFTilingOptions &setLoopType (LoopType type) {
40
+ loopType = type;
41
+ return *this ;
42
+ }
43
+
36
44
// / Computation function that returns the tile sizes to use for each loop.
37
45
// / Returning a tile size of zero implies no tiling for that loop. If the
38
46
// / size of the returned vector is smaller than the number of loops, the inner
@@ -50,6 +58,17 @@ struct SCFTilingOptions {
50
58
// / proper interaction with folding.
51
59
SCFTilingOptions &setTileSizes (ArrayRef<OpFoldResult> tileSizes);
52
60
61
+ // / The interchange vector to reorder the tiled loops.
62
+ SmallVector<int64_t > interchangeVector = {};
63
+ SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
64
+ interchangeVector = llvm::to_vector (interchange);
65
+ return *this ;
66
+ }
67
+
68
+ // -------------------------------------------------------------------------//
69
+ // Options related to tiling using `scf.forall`.
70
+ // -------------------------------------------------------------------------//
71
+
53
72
// / Computation function that returns the number of threads to use for
54
73
// / each loop. Returning a num threads of zero implies no tiling for that
55
74
// / loop. If the size of the returned vector is smaller than the number of
@@ -70,21 +89,6 @@ struct SCFTilingOptions {
70
89
// / function that computes num threads at the point they are needed.
71
90
SCFTilingOptions &setNumThreads (ArrayRef<OpFoldResult> numThreads);
72
91
73
- // / The interchange vector to reorder the tiled loops.
74
- SmallVector<int64_t > interchangeVector = {};
75
- SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
76
- interchangeVector = llvm::to_vector (interchange);
77
- return *this ;
78
- }
79
-
80
- // / Specify which loop construct to use for tile and fuse.
81
- enum class LoopType { ForOp, ForallOp };
82
- LoopType loopType = LoopType::ForOp;
83
- SCFTilingOptions &setLoopType (LoopType type) {
84
- loopType = type;
85
- return *this ;
86
- }
87
-
88
92
// / Specify mapping of loops to devices. This is only respected when the loop
89
93
// / constructs support such a mapping (like `scf.forall`). Will be ignored
90
94
// / when using loop constructs that don't support such a mapping (like
@@ -117,6 +121,96 @@ struct SCFTilingOptions {
117
121
reductionDims.insert (dims.begin (), dims.end ());
118
122
return *this ;
119
123
}
124
+
125
+ // -------------------------------------------------------------------------//
126
+ // Options related to tiling using custom loop.
127
+ // -------------------------------------------------------------------------//
128
+
129
+ // For generating the inter-tile loops using a custom loop, two callback
130
+ // functions are needed:
131
+ // 1. That generates the "loop header", i.e. the loop that iterates over the
132
+ // different tiles.
133
+ // 2. That generates the loop terminator
134
+ //
135
+ // For the `scf.forall` case, the callback to generate the loop header would generate
136
+ //
137
+ // ```mlir
138
+ // scf.forall (...) = ... {
139
+ // ..
140
+ // }
141
+ // ```
142
+ //
143
+ // and the callback to generate the loop terminator would generate the
144
+ // `scf.in_parallel` region
145
+ //
146
+ // ```mlir
147
+ // scf.forall (...) = ... {
148
+ // scf.in_parallel {
149
+ // tensor.parallel_insert_slice ...
150
+ // }
151
+ // }
152
+ // ```
153
+ //
154
+
155
+ // Information that is to be returned by loop header callback needed for the
156
+ // rest of the tiled code generation.
157
+ // - `loops`: The generated loops
158
+ // - `tileOffset`: The values that represent the offset of the iteration space
159
+ // tile.
160
+ // - `tileSizes` : The values that represent the size of the iteration space
161
+ // tile.
162
+ // - `destinationTensors` : The tensors to use as destinations during tiling.
163
+ struct CustomLoopHeaderInfo {
164
+ SmallVector<LoopLikeOpInterface> loops;
165
+ SmallVector<OpFoldResult> tileOffset;
166
+ SmallVector<OpFoldResult> tileSizes;
167
+ SmallVector<Value> destinationTensors;
168
+ };
169
+
170
+ // Type of the callback function that generates the loop headers.
171
+ // - `loopRanges` : Values that represent the full size of the iteration space
172
+ // being tiled.
173
+ // - `givenTileSizes` : The tile sizes that are to be used to tile the
174
+ // iteration space.
175
+ // - `destinationTensors` : The tensors to use as destinations for the results
176
+ // of the tiled loop for loops that implement
177
+ // `DestinationStyleOpInterface`.
178
+ // Returns the `CustomLoopHeaderInfo` object (described above). It is expected
179
+ // that this function sets the insertion point of `rewriter` to the program
180
+ // point where the intra-tile loop computation is to be generated.
181
+ using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
182
+ RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
183
+ ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
184
+
185
+ // Type of the callback function that generates the loop terminator.
186
+ // - `tiledResults` : Tiles of the result computed for the iteration space
187
+ // tile.
188
+ // - `resultOffsets` : For each of the `tiledResults`, the offset at which
189
+ // the result tile is to be "inserted" back into the
190
+ // destination tensor.
191
+ // - `resultSizes` : For each of the `tiledResults`, the size of the result
192
+ // tile that is to be "inserted" back into the destination
193
+ // tensor.
194
+ // Returns a `LogicalResult` indicating whether the terminator was generated.
195
+ using GenerateLoopTerminatorFn = std::function<LogicalResult(
196
+ RewriterBase &rewriter, Location loc, ValueRange tiledResults,
197
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
198
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
199
+ ValueRange destinationTensors)>;
200
+
201
+ // Callback function to generate the inter-tile loop header.
202
+ GenerateLoopHeaderFn generateLoopHeaderFn = nullptr ;
203
+ // Callback function to generate the inter-tile loop terminator.
204
+ GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr ;
205
+ // Helper function to set the callbacks for inter-tile loop header and
206
+ // terminator functions when using a custom operation for the loop.
207
+ SCFTilingOptions &
208
+ setCustomLoopGenerationFns (GenerateLoopHeaderFn headerFn,
209
+ GenerateLoopTerminatorFn terminatorFn) {
210
+ generateLoopHeaderFn = std::move (headerFn);
211
+ generateLoopTerminatorFn = std::move (terminatorFn);
212
+ return *this ;
213
+ }
120
214
};
121
215
122
216
// / Transformation information returned after tiling.
0 commit comments