@@ -85,36 +85,6 @@ struct SCFTilingOptions {
8585 return *this ;
8686 }
8787
88- // / Specify how reduction dimensions should be tiled.
89- // /
90- // / Tiling can be thought of as splitting a dimension into 2 and materializing
91- // / the outer dimension as a loop:
92- // /
93- // / op[original] -> op[original / x, x] -> loop[original] { op[x] }
94- // /
95- // / For parallel dimensions, the split can only happen in one way, with both
96- // / dimensions being parallel. For reduction dimensions however, there is a
97- // / choice in how we split the reduction dimension. This enum exposes this
98- // / choice.
99- enum class ReductionTilingStrategy {
100- // [reduction] -> [reduction1, reduction2]
101- // -> loop[reduction1] { [reduction2] }
102- FullReduction,
103- // [reduction] -> [reduction1, parallel2]
104- // -> loop[reduction1] { [parallel2] }; merge[reduction1]
105- PartialReductionOuterReduction,
106- // [reduction] -> [parallel1, reduction2]
107- // -> loop[parallel1] { [reduction2] }; merge[parallel1]
108- PartialReductionOuterParallel
109- };
110- ReductionTilingStrategy reductionStrategy =
111- ReductionTilingStrategy::FullReduction;
112- SCFTilingOptions &
113- setReductionTilingStrategy (ReductionTilingStrategy strategy) {
114- reductionStrategy = strategy;
115- return *this ;
116- }
117-
11888 // / Specify mapping of loops to devices. This is only respected when the loop
11989 // / constructs support such a mapping (like `scf.forall`). Will be ignored
12090 // / when using loop constructs that dont support such a mapping (like
@@ -132,16 +102,11 @@ struct SCFTilingResult {
132102 // / matter except the last op. The replacements are expected to be the results
133103 // / of the last op.
134104 SmallVector<Operation *> tiledOps;
135- // / The initial destination values passed to the tiled operations.
136- SmallVector<Value> initialValues;
137105 // / The `scf.for` operations that iterate over the tiles.
138106 SmallVector<LoopLikeOpInterface> loops;
139- // / The result generated by the loop nest in tiling, may hold partial results,
140- // / which need to be merged to match the computation of the untiled operation.
141- // / `mergeResult` contains the operations used to perform this merge from
142- // / partial results and the values that can be used as replacements of
143- // / the untiled operation.
144- MergeResult mergeResult;
107+ // / Values to use as replacements for the untiled op. Is the same size as the
108+ // / number of results of the untiled op.
109+ SmallVector<Value> replacements;
145110 // / Slices generated after tiling that can be used for fusing with the tiled
146111 // / producer.
147112 SmallVector<Operation *> generatedSlices;
@@ -335,6 +300,20 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
335300FailureOr<SmallVector<scf::ForOp>>
336301lowerToLoopsUsingSCFForOp (RewriterBase &rewriter, TilingInterface op);
337302
303+ // / Transformation information returned after reduction tiling.
304+ struct SCFReductionTilingResult {
305+ // / The partial reduction tiled op generated.
306+ SmallVector<Operation *> parallelTiledOps;
307+ // / The final reduction operation merging all the partial reductions.
308+ SmallVector<Operation *> mergeOps;
309+ // / Initial values used for reduction.
310+ SmallVector<Value> initialValues;
311+ // / The loop operations that iterate over the tiles.
312+ SmallVector<LoopLikeOpInterface> loops;
313+ // / The replacements to use for the results of the tiled operation.
314+ SmallVector<Value> replacements;
315+ };
316+
338317// / Method to tile a reduction and generate a parallel op within a serial loop.
339318// / Each of the partial reductions are calculated in parallel. Then after the
340319// / loop all the partial reduction are merged into a final reduction.
@@ -359,7 +338,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
359338// / %6 = linalg.generic %1 ["parallel", "reduction"]
360339// / : tensor<7x4xf32> -> tensor<7xf32>
361340// / ```
362- FailureOr<scf::SCFTilingResult >
341+ FailureOr<scf::SCFReductionTilingResult >
363342tileReductionUsingScf (RewriterBase &b, PartialReductionOpInterface op,
364343 ArrayRef<OpFoldResult> tileSize);
365344
0 commit comments