 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
 
 #include <numeric>
 #include <optional>
@@ -32,329 +28,67 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
-namespace {
-/// A simple struct to represent ReassociationIndices as an inclusive interval.
-/// It's designed to be feasibly minimal, so the call sites should manage the
-/// validity of the range manually.
-struct ReassociationIndexRange {
-  /// FIXME: Signed type is used for consistency with ReassociationIndices.
-  /// We should consider refactoring all reassociation utilities to use unsigned
-  /// types.
-  int64_t leftIdx = 0, rightIdx = 0;
-
-  /// Util for manual checks of the range's validity
-  LogicalResult verify() const {
-    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
-  }
-
-  /// Checks range's containment within another range. Treats the edges
-  /// non-exclusively.
-  bool isInRange(const ReassociationIndexRange &outerRange) const {
-    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
-  }
-
-  unsigned size() const {
-    assert(succeeded(verify()));
-    return rightIdx - leftIdx + 1;
-  }
-  bool containsSingleIndex() const { return size() == 1; }
-
-  /// Collects indices that do not overlap between this and another range.
-  ReassociationIndices
-  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
-    if (rightIdx < rhs.leftIdx) {
-      // The intervals do not overlap - concatenate the indices from both.
-      auto jointFullIndices = getFullIndices();
-      jointFullIndices.append(rhs.getFullIndices());
-      return jointFullIndices;
-    }
-    ReassociationIndices result;
-    // Handle the chunk left of the overlapping range.
-    int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
-    int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
-    llvm::append_range(result, llvm::seq(leftStart, leftEnd));
-    // Handle the chunk right of the overlapping range. Symmetrically, we should
-    // skip the edge of the overlap AND include the rightmost index.
-    int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
-    int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
-    if (rightStart < rightEnd)
-      llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
-    return result;
-  }
-
-  /// Converts the range into ReassociationIndices.
-  ReassociationIndices getFullIndices() const {
-    ReassociationIndices result;
-    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
-      result.push_back(idx);
-    }
-    return result;
-  }
-};
-} // namespace
-
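Aside, for illustration only (not part of the diff): for the overlapping ranges {leftIdx = 0, rightIdx = 2} and {leftIdx = 1, rightIdx = 4}, getNonOverlappingIndicesWith returns {0, 3, 4}, i.e. the indices on which the two ranges disagree; for disjoint ranges such as {0, 1} and {3, 4} it simply concatenates both, giving {0, 1, 3, 4}. This symmetric-difference view is what the ambiguity check further down the diff relied on.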
-/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
-/// sequence that can be collapsed into a dynamic dimension (at least one must
-/// be present in the source).
-/// By default, lazily returns once the first dynamic dimension has been found.
-/// Setting `matchGreedily` as `true` will also mark all subsequent
-/// source dimensions for collapsing into the target.
-static FailureOr<ReassociationIndexRange>
-findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
-                                    int64_t sourceStartIdx,
-                                    bool matchGreedily = false) {
-  const unsigned numSourceDims = sourceShape.size();
-  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
-
-  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
-  for (; iterationRange.isInRange(sourceShapeAsRange);
-       iterationRange.rightIdx++) {
-    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
-    if (sourceSize == ShapedType::kDynamic) {
-      resultRange = iterationRange;
-      break;
-    }
-  }
-  if (!resultRange)
-    return failure();
-  if (matchGreedily)
-    resultRange->rightIdx = sourceShapeAsRange.rightIdx;
-  return *resultRange;
-}
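Aside, for illustration only (not part of the diff): for sourceShape 2x3x?x5 and sourceStartIdx 0, the lazy match returns the inclusive range {0, 2}, i.e. everything up to and including the first dynamic dimension; with matchGreedily set it extends to {0, 3}.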
+std::optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                         ArrayRef<int64_t> targetShape) {
+  if (sourceShape.size() <= targetShape.size())
+    return std::nullopt;
+  unsigned sourceDim = 0;
+  SmallVector<ReassociationIndices> reassociationMap;
+  reassociationMap.reserve(targetShape.size());
 
-/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
-/// sequence of static dimensions such that their product matches `targetSize`.
-/// By default, lazily returns once the product matches the target size. Setting
-/// `matchGreedily` as `true` will append all neighboring unit dimensions
-/// (dimensions of 1) to the match.
-static FailureOr<ReassociationIndexRange>
-findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
-                              int64_t sourceStartIdx, int64_t targetSize,
-                              bool matchGreedily = false) {
-  const unsigned numSourceDims = sourceShape.size();
-  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
-
-  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
-  while (iterationRange.isInRange(sourceShapeAsRange)) {
-    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
-    if (sourceSize == ShapedType::kDynamic) {
-      // Reassociation for a static dim cannot include a dynamic dim. Reset
-      // induction variables to essentially restart the loop from the next
-      // source dimension.
-      prodOfCollapsedDims = 1;
-      iterationRange = {iterationRange.rightIdx + 1,
-                        iterationRange.rightIdx + 1};
-      continue;
-    }
-    prodOfCollapsedDims *= sourceSize;
-    // If the target size has been exceeded without matching, we need to shift
-    // the range start right. From the start of the range, roll back the
-    // multiplication until the target size exceeds the product again.
-    while (prodOfCollapsedDims > targetSize &&
-           !iterationRange.containsSingleIndex()) {
-      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
-      prodOfCollapsedDims /= frontSourceSize;
-      // Shrink the range rightwards
-      iterationRange.leftIdx++;
-    }
-    // We could've reached the target size with the current dimension,
-    // also as a result of the above shift to right.
-    if (prodOfCollapsedDims == targetSize) {
-      resultRange = iterationRange;
+  while (sourceDim < sourceShape.size()) {
+    unsigned targetDim = reassociationMap.size();
+    // If we have mapped all the target dimensions, stop and handle the
+    // remaining tail of size-1 dimensions explicitly.
+    if (targetDim == targetShape.size())
       break;
-    }
-    // Increment the iteration range
-    iterationRange.rightIdx++;
-  }
-  if (!resultRange)
-    return failure();
-  if (matchGreedily) {
-    // We now want to collect all unit dimensions directly after the target
-    // product match. Advance the iterator to avoid OOB when the product match
-    // happens at the last element.
-    iterationRange.rightIdx++;
-    while (iterationRange.isInRange(sourceShapeAsRange) &&
-           sourceShape[iterationRange.rightIdx] == 1) {
-      resultRange = iterationRange;
-      iterationRange.rightIdx++;
-    }
-  }
-  return *resultRange;
-}
 
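Aside (not part of the diff): the helper deleted above is, at heart, a sliding-window product search. Below is a minimal standalone sketch of that idea; it is illustrative only, the findWindowWithProduct name is made up, and it omits the dynamic-dimension reset and the greedy absorption of trailing unit dimensions that the real helper performed.

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Returns the inclusive window [left, right] of `shape`, starting no earlier
// than `start`, whose element product equals `target`, or std::nullopt if no
// such window exists.
static std::optional<std::pair<unsigned, unsigned>>
findWindowWithProduct(const std::vector<int64_t> &shape, unsigned start,
                      int64_t target) {
  int64_t product = 1;
  unsigned left = start;
  for (unsigned right = start; right < shape.size(); ++right) {
    product *= shape[right];
    // Overshot the target: shrink the window from the left.
    while (product > target && left < right)
      product /= shape[left++];
    if (product == target)
      return std::make_pair(left, right);
  }
  return std::nullopt;
}

// Example: findWindowWithProduct({2, 3, 5}, 0, 15) yields {1, 2}, since 3 * 5 == 15.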
-/// Attempts to find a valid collapsing reassociation of `sourceShape` into
-/// `targetShape` through a simple traversal. If successful, an array of source
-/// index ranges is returned, correspondingly to each dimension in the target
-/// shape. The resulting indices shall fully cover the `sourceShape` without
-/// overlaps.
-///
-/// The algorithm is essentially a lazy one, searching for non-greedy matches -
-/// it will only yield a greedy match for the last target dimension.
-/// FIXME: The algorithm can only backtrack when it needs to append an offset
-/// for a static target dimension to the preceding dynamic one (this retains the
-/// linear complexity). As feasible, consider adding further backtracking
-/// routines to enable more reassociations, e.g.:
-/// - ?x2x?x2 into ?x2
-static FailureOr<SmallVector<ReassociationIndexRange>>
-findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
-                                   ArrayRef<int64_t> targetShape) {
-  unsigned numSourceDims = sourceShape.size(),
-           numTargetDims = targetShape.size();
-  assert(numSourceDims > numTargetDims);
-  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-
-  SmallVector<ReassociationIndexRange> reassocRanges;
-  reassocRanges.reserve(numTargetDims);
-  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
-  // cases, e.g.:
-  // - ?x2x3x5 into ?x15
-  std::optional<int64_t> prevTargetSize = std::nullopt;
-  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
-       targetDimIdx < numTargetDims; ++targetDimIdx) {
-    int64_t targetSize = targetShape[targetDimIdx];
-    // Simply check if there are any subsequent target dimensions left - if not,
-    // the match must be made greedily.
-    bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
-    FailureOr<ReassociationIndexRange> sourceRange;
-    if (targetSize == ShapedType::kDynamic) {
-      sourceRange = findReassociationRangeForDynamicDim(
-          sourceShape, sourceDimIdx, shouldMatchGreedily);
-    } else {
-      sourceRange = findReassociationRangeForSize(
-          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
+    int64_t currTargetShape = targetShape[targetDim];
+    while (sourceDim < (sourceShape.size() - 1) &&
+           sourceShape[sourceDim] != ShapedType::kDynamic &&
+           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      currIndices.push_back(sourceDim++);
     }
 
-    // Run sanity checks on the returned index range.
-    if (failed(sourceRange) || failed(sourceRange->verify()) ||
-        !sourceRange->isInRange(sourceShapeAsRange))
-      return failure();
-    if (sourceRange->leftIdx > sourceDimIdx) {
-      // If some source dimensions had to be skipped in order to find a match,
-      // they must be collapsed into the directly preceding dynamic dimension.
-      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
-        return failure();
-      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
-    }
-
-    // Store the gathered information as required for the next iteration.
-    prevTargetSize = targetSize;
-    sourceDimIdx = sourceRange->rightIdx + 1;
-    reassocRanges.push_back(*sourceRange);
+    // If the current expanded dimension is dynamic, then the collapsed
+    // dimension should also be dynamic and the product of all previous
+    // unprocessed dimensions of the expanded shape should be 1.
+    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
+        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+      return std::nullopt;
+
+    // If the collapsed dim is dynamic, the current expanded dim should also
+    // be dynamic.
+    if (currTargetShape == ShapedType::kDynamic &&
+        sourceShape[sourceDim] != ShapedType::kDynamic)
+      return std::nullopt;
+
+    // For static shapes, the product of dimensions of the expanded shape
+    // should match the collapsed dimension shape.
+    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+      return std::nullopt;
+
+    currIndices.push_back(sourceDim++);
+    reassociationMap.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMap.back(), currIndices);
+    prodOfCollapsedDims = 1;
   }
-  // Fail if the source shape wasn't a full match for the target shape. We only
-  // need to check the last recorded index - any other gaps should have been
-  // mended by the main loop.
-  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
-    return failure();
-  return reassocRanges;
-}
-
-/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
-/// the shapes right-to-left.
-static FailureOr<SmallVector<ReassociationIndexRange>>
-findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
-                                   ArrayRef<int64_t> targetShape,
-                                   bool iterateRightToLeft) {
-  if (!iterateRightToLeft)
-    return findReassociationRangesForCollapse(sourceShape, targetShape);
-  // NB: To iterate right-to-left, we currently reverse the shapes and then
-  // reverse the result back. The reversed shapes must not be temporary, as
-  // we're passing through an ArrayRef.
-  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
-  // this approach is chosen for readability of the main implementation.
-  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
-                       targetToReverse = targetShape.vec();
-  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
-  std::reverse(targetToReverse.begin(), targetToReverse.end());
-  auto invertedRanges =
-      findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
-  if (failed(invertedRanges))
-    return failure();
-  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
-  unsigned numSourceDims = sourceShape.size();
-  // We have received the ranges for inverted shapes. Now we have to invert
-  // the ranges back to correspond with the original source shape.
-  for (auto &range : rangesToInvert) {
-    int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
-    range.leftIdx = numSourceDims - 1 - invRightIdx;
-    range.rightIdx = numSourceDims - 1 - invLeftIdx;
-  }
-  // Also invert the ordering of the ranges to correspond with the original
-  // target shape.
-  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
-  return rangesToInvert;
-}
-
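Aside, for illustration only (not part of the diff): when collapsing 2x1x3 into 2x3, the left-to-right pass yields the ranges {0} and {1, 2}, while the right-to-left pass, run on the reversed shapes 3x1x2 and 3x2, yields {0} and {1, 2} in reversed coordinates. With a source rank of 3, the reversed range {leftIdx = 1, rightIdx = 2} maps back to {3 - 1 - 2, 3 - 1 - 1} = {0, 1}, and reversing the range order gives {0, 1} and {2}. The two passes disagree only on source index 1; since sourceShape[1] == 1, the removed getReassociationIndicesForCollapse below tolerated the ambiguity and kept the forward result {{0}, {1, 2}}.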
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  unsigned numSourceDims = sourceShape.size(),
-           numTargetDims = targetShape.size();
-  // We're supposed to search for a collapsing reassociation. If the sizes
-  // match, there's no actual collapsing taking place - it's either a no-op or a
-  // `tensor.reshape`-style reassociation (that would be beyond the scope of
-  // this utility).
-  if (numSourceDims <= numTargetDims)
-    return std::nullopt;
-  // Early handling for scalar target types.
-  if (numTargetDims == 0) {
-    ReassociationIndices allSourceIndices;
-    allSourceIndices.reserve(numSourceDims);
-    for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
-         ++sourceDimIdx) {
-      int64_t sourceSize = sourceShape[sourceDimIdx];
-      // All source dimensions must be unit or dynamic.
-      if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
-        return std::nullopt;
-      allSourceIndices.push_back(sourceDimIdx);
-    }
-    return SmallVector<ReassociationIndices>{allSourceIndices};
-  }
-
-  // Collect source ranges by iterating over the target shape left-to-right.
-  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
-      findReassociationRangesForCollapse(sourceShape, targetShape);
-  if (failed(maybeForwardRanges))
-    return std::nullopt;
-  auto &ranges = *maybeForwardRanges;
-  // Now do the same in reverse. We need to get another valid reassociation
-  // through some other strategy, and then compare the results in order to
-  // disambiguate mixed subshapes, such as:
-  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
-  // This leads us to lose some of the reassociation opportunities that can only
-  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
-  // backtracking, the algorithm will fail right-to-left. However, this is the
-  // best way to preserve correctness.
-  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
-      findReassociationRangesForCollapse(sourceShape, targetShape,
-                                         /*iterateRightToLeft=*/true);
-  if (failed(maybeReverseRanges))
-    return std::nullopt;
-  auto &reverseRanges = *maybeReverseRanges;
-
-  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
+  // All the dimensions in the target must have been processed.
+  if (reassociationMap.size() != targetShape.size())
     return std::nullopt;
-  // Now we can check for ambiguity of each target dimension's reassociation. If
-  // successful, we put the full indices into our result map for the target
-  // shape.
-  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
-  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
-       ++targetDimIdx) {
-    ReassociationIndexRange &range = ranges[targetDimIdx];
-    ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
-    // Get non-overlapping indices between the ranges
-    ReassociationIndices nonMatchingIndices =
-        range.getNonOverlappingIndicesWith(reverseRange);
-    // Unit dimensions can be collapsed wherever - this is the only ambiguity
-    // that we allow.
-    for (int64_t sourceDimIdx : nonMatchingIndices) {
-      if (sourceShape[sourceDimIdx] != 1)
-        return std::nullopt;
-    }
-    reassociationMap[targetDimIdx] = range.getFullIndices();
+  // Process any remaining entries in the source shape. They all need to be
+  // 1 or dynamic.
+  for (; sourceDim < sourceShape.size(); sourceDim++) {
+    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+        sourceShape[sourceDim] != 1)
+      return std::nullopt;
+    // The map is empty when the target type is a scalar.
+    if (!reassociationMap.empty())
+      reassociationMap.back().push_back(sourceDim);
   }
   return reassociationMap;
 }
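Aside (not part of the diff): a minimal sketch of how the restored utility can be queried, assuming the usual mlir/Dialect/Utils/ReshapeOpsUtils.h header; the shapes and the exampleCollapseQueries wrapper are illustrative only.

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <optional>

static void exampleCollapseQueries() {
  // Static case: collapse 2x3x4x5 into 6x20 -> groups {0, 1} and {2, 3}.
  llvm::SmallVector<int64_t> source = {2, 3, 4, 5};
  llvm::SmallVector<int64_t> target = {6, 20};
  std::optional<llvm::SmallVector<mlir::ReassociationIndices>> map =
      mlir::getReassociationIndicesForCollapse(source, target);
  assert(map && (*map)[0] == mlir::ReassociationIndices({0, 1}));

  // Dynamic case: ?x2x3 into ?x6 -> groups {0} and {1, 2}. A dynamic target
  // dimension must line up with a dynamic source dimension.
  llvm::SmallVector<int64_t> dynSource = {mlir::ShapedType::kDynamic, 2, 3};
  llvm::SmallVector<int64_t> dynTarget = {mlir::ShapedType::kDynamic, 6};
  map = mlir::getReassociationIndicesForCollapse(dynSource, dynTarget);
  assert(map && map->size() == 2);
}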