@@ -33,14 +33,6 @@ using SCFTileSizeComputationFunction =
33
33
34
34
// / Options to use to control tiling.
35
35
struct SCFTilingOptions {
36
- // / Specify which loop construct to use for tile and fuse.
37
- enum class LoopType { ForOp, ForallOp, CustomOp };
38
- LoopType loopType = LoopType::ForOp;
39
- SCFTilingOptions &setLoopType (LoopType type) {
40
- loopType = type;
41
- return *this ;
42
- }
43
-
44
36
// / Computation function that returns the tile sizes to use for each loop.
45
37
// / Returning a tile size of zero implies no tiling for that loop. If the
46
38
// / size of the returned vector is smaller than the number of loops, the inner
@@ -58,17 +50,6 @@ struct SCFTilingOptions {
58
50
// / proper interaction with folding.
59
51
SCFTilingOptions &setTileSizes (ArrayRef<OpFoldResult> tileSizes);
60
52
61
- // / The interchange vector to reorder the tiled loops.
62
- SmallVector<int64_t > interchangeVector = {};
63
- SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
64
- interchangeVector = llvm::to_vector (interchange);
65
- return *this ;
66
- }
67
-
68
- // -------------------------------------------------------------------------//
69
- // Options related to tiling using `scf.forall`.
70
- // -------------------------------------------------------------------------//
71
-
72
53
// / Computation function that returns the number of threads to use for
73
54
// / each loop. Returning a num threads of zero implies no tiling for that
74
55
// / loop. If the size of the returned vector is smaller than the number of
@@ -89,6 +70,21 @@ struct SCFTilingOptions {
89
70
// / function that computes num threads at the point they are needed.
90
71
SCFTilingOptions &setNumThreads (ArrayRef<OpFoldResult> numThreads);
91
72
73
+ // / The interchange vector to reorder the tiled loops.
74
+ SmallVector<int64_t > interchangeVector = {};
75
+ SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
76
+ interchangeVector = llvm::to_vector (interchange);
77
+ return *this ;
78
+ }
79
+
80
+ // / Specify which loop construct to use for tile and fuse.
81
+ enum class LoopType { ForOp, ForallOp };
82
+ LoopType loopType = LoopType::ForOp;
83
+ SCFTilingOptions &setLoopType (LoopType type) {
84
+ loopType = type;
85
+ return *this ;
86
+ }
87
+
92
88
// / Specify mapping of loops to devices. This is only respected when the loop
93
89
// / constructs support such a mapping (like `scf.forall`). Will be ignored
94
90
// / when using loop constructs that dont support such a mapping (like
@@ -121,98 +117,6 @@ struct SCFTilingOptions {
121
117
reductionDims.insert (dims.begin (), dims.end ());
122
118
return *this ;
123
119
}
124
-
125
- // -------------------------------------------------------------------------//
126
- // Options related to tiling using custom loop.
127
- // -------------------------------------------------------------------------//
128
-
129
- // For generating the inter-tile loops using a custom loop, two callback
130
- // functions are needed
131
- // 1. That generates the "loop header", i.e. the loop that iterates over the
132
- // different tiles.
133
- // 2. That generates the loop terminator
134
- //
135
- // For `scf.forall` case the call back to generate loop header would generate
136
- //
137
- // ```mlir
138
- // scf.forall (...) = ... {
139
- // ..
140
- // }
141
- // ```
142
- //
143
- // and the call back to generate the loop terminator would generate the
144
- // `scf.in_parallel` region
145
- //
146
- // ```mlir
147
- // scf.forall (...) = ... {
148
- // scf.in_parallel {
149
- // tensor.parallel_insert_slice ...
150
- // }
151
- // }
152
- // ```
153
- //
154
-
155
- // Information that is to be returned by the callback to generate the loop
156
- // header needed for the rest of the tiled codegeneration.
157
- // - `loops`: The generated loops
158
- // - `tileOffset`: The values that represent the offset of the iteration space
159
- // tile
160
- // - `tileSizes` : The values that represent the size of the iteration space
161
- // tile.
162
- // - `destinationTensors` : The tensors to use as destinations during tiling.
163
- struct CustomLoopHeaderInfo {
164
- SmallVector<LoopLikeOpInterface> loops;
165
- SmallVector<OpFoldResult> tileOffset;
166
- SmallVector<OpFoldResult> tileSizes;
167
- SmallVector<Value> destinationTensors;
168
- };
169
-
170
- // Type of the callback function that generates the loop headers.
171
- // - `loopRanges` : Values that represent the full size of the iteration space
172
- // being tiled.
173
- // - `giveTileSizes` : The tile sizes that are to be used to tile the
174
- // iteration
175
- // space.
176
- // - `destinationTensors` : The tensors to use as destinations for the results
177
- // of the tiled loop for loops that implement
178
- // `DestinationStyleOpInterface`.
179
- // Returns the `CustomLoopHeaderInfo` object (described above). it is expected
180
- // that this function sets the insertion point of `rewriter` to the program
181
- // point where the intra-tile loop computation is to be generated.
182
- using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
183
- RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
184
- ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
185
-
186
- // Type of the callback function that generates the loop terminator.
187
- // - `tiledResults` : Tiles of the result computed for the iteration space
188
- // tile
189
- // - `resultOffsets` : For each of the `tiledResults`, the offset at which
190
- // the result tile is to be "inserted" back into the
191
- // destination tensor.
192
- // - `resultSizes` : For each of the `tiledResults`, the size of the result
193
- // tile
194
- // that is to be "inserted" back into the destination
195
- // tensor.
196
- // Returns the `CustomLoopHeaderInfo` object (described above)
197
- using GenerateLoopTerminatorFn = std::function<LogicalResult(
198
- RewriterBase &rewriter, Location loc, ValueRange tiledResults,
199
- ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
200
- ArrayRef<SmallVector<OpFoldResult>> resultSizes,
201
- ValueRange destinationTensors)>;
202
-
203
- // Callback function to generate the inter-tile loop header.
204
- GenerateLoopHeaderFn generateLoopHeaderFn = nullptr ;
205
- // Callback function to generate the inter-tile loop terminator.
206
- GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr ;
207
- // Helper function to set the callbacks for inter-tile loop header and
208
- // terminator functions when using a custom operation for the loop.
209
- SCFTilingOptions &
210
- setCustomLoopGenerationFns (GenerateLoopHeaderFn headerFn,
211
- GenerateLoopTerminatorFn terminatorFn) {
212
- generateLoopHeaderFn = std::move (headerFn);
213
- generateLoopTerminatorFn = std::move (terminatorFn);
214
- return *this ;
215
- }
216
120
};
217
121
218
122
// / Transformation information returned after tiling.
0 commit comments