@@ -109,90 +109,6 @@ struct LinearizeVectorizable final
109109 }
110110};
111111
112- template <typename TOp>
113- static bool stridesAllOne (TOp op) {
114- static_assert (
115- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117- " expected vector.extract_strided_slice or vector.insert_strided_slice" );
118- ArrayAttr strides = op.getStrides ();
119- return llvm::all_of (strides, isOneInteger);
120- }
121-
122- // / Convert an array of attributes into a vector of integers, if possible.
123- static FailureOr<SmallVector<int64_t >> intsFromArrayAttr (ArrayAttr attrs) {
124- if (!attrs)
125- return failure ();
126- SmallVector<int64_t > ints;
127- ints.reserve (attrs.size ());
128- for (auto attr : attrs) {
129- if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130- ints.push_back (intAttr.getInt ());
131- } else {
132- return failure ();
133- }
134- }
135- return ints;
136- }
137-
138- // / Consider inserting a vector of shape `small` into a vector of shape `large`,
139- // / at position `offsets`: this function enumeratates all the indices in `large`
140- // / that are written to. The enumeration is with row-major ordering.
141- // /
142- // / Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143- // / positions written to are (1,3) and (1,4), which have linearized indices 8
144- // / and 9. So [8,9] is returned.
145- // /
146- // / The length of the returned vector is equal to the number of elements in
147- // / the shape `small` (i.e. the product of dimensions of `small`).
148- SmallVector<int64_t > static getStridedSliceInsertionIndices (
149- ArrayRef<int64_t > small, ArrayRef<int64_t > large,
150- ArrayRef<int64_t > offsets) {
151-
152- // Example of alignment between, `large`, `small` and `offsets`:
153- // large = 4, 5, 6, 7, 8
154- // small = 1, 6, 7, 8
155- // offsets = 2, 3, 0
156- //
157- // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158- assert ((large.size () >= small.size ()) &&
159- " rank of 'large' cannot be lower than rank of 'small'" );
160- assert ((large.size () >= offsets.size ()) &&
161- " rank of 'large' cannot be lower than the number of offsets" );
162- unsigned delta = large.size () - small.size ();
163- unsigned nOffsets = offsets.size ();
164- auto getSmall = [&](int64_t i) -> int64_t {
165- return i >= delta ? small[i - delta] : 1 ;
166- };
167- auto getOffset = [&](int64_t i) -> int64_t {
168- return i < nOffsets ? offsets[i] : 0 ;
169- };
170-
171- // Using 2 vectors of indices, at each iteration populate the updated set of
172- // indices based on the old set of indices, and the size of the small vector
173- // in the current iteration.
174- SmallVector<int64_t > indices{0 };
175- int64_t stride = 1 ;
176- for (int i = large.size () - 1 ; i >= 0 ; --i) {
177- int64_t currentSize = indices.size ();
178- int64_t smallSize = getSmall (i);
179- int64_t nextSize = currentSize * smallSize;
180- SmallVector<int64_t > nextIndices (nextSize);
181- int64_t *base = nextIndices.begin ();
182- int64_t offset = getOffset (i) * stride;
183- for (int j = 0 ; j < smallSize; ++j) {
184- for (int k = 0 ; k < currentSize; ++k) {
185- base[k] = indices[k] + offset;
186- }
187- offset += stride;
188- base += currentSize;
189- }
190- stride *= large[i];
191- indices = std::move (nextIndices);
192- }
193- return indices;
194- }
195-
196112// / This pattern converts a vector.extract_strided_slice operation into a
197113// / vector.shuffle operation that has a rank-1 (linearized) operand and result.
198114// /
@@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final
231147
232148 // Expect a legalization failure if the strides are not all 1 (if ever the
233149 // verifier for extract_strided_slice allows non-1 strides).
234- if (! stridesAllOne ( extractStridedSliceOp)) {
150+ if (extractStridedSliceOp. hasNonUnitStrides ( )) {
235151 return rewriter.notifyMatchFailure (
236152 extractStridedSliceOp,
237153 " extract_strided_slice with strides != 1 not supported" );
238154 }
239155
240- FailureOr<SmallVector<int64_t >> offsets =
241- intsFromArrayAttr ( extractStridedSliceOp.getOffsets () );
242- if (failed (offsets )) {
156+ FailureOr<SmallVector<int64_t >> indices =
157+ extractStridedSliceOp.getLinearIndices ( );
158+ if (failed (indices )) {
243159 return rewriter.notifyMatchFailure (extractStridedSliceOp,
244- " failed to get integer offsets " );
160+ " failed to get indices " );
245161 }
246162
247- ArrayRef<int64_t > inputShape =
248- extractStridedSliceOp.getSourceVectorType ().getShape ();
249-
250- ArrayRef<int64_t > outputShape = extractStridedSliceOp.getType ().getShape ();
251-
252- SmallVector<int64_t > indices = getStridedSliceInsertionIndices (
253- outputShape, inputShape, offsets.value ());
254-
255163 Value srcVector = adaptor.getVector ();
256- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
257- extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
164+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(extractStridedSliceOp,
165+ flatOutputType, srcVector,
166+ srcVector, indices.value ());
258167 return success ();
259168 }
260169};
@@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final
298207
299208 // Expect a legalization failure if the strides are not all 1 (if ever the
300209 // verifier for insert_strided_slice allows non-1 strides).
301- if (! stridesAllOne ( insertStridedSliceOp)) {
210+ if (insertStridedSliceOp. hasNonUnitStrides ( )) {
302211 return rewriter.notifyMatchFailure (
303212 insertStridedSliceOp,
304213 " insert_strided_slice with strides != 1 not supported" );
305214 }
306215
307- VectorType inputType = insertStridedSliceOp.getValueToStore ().getType ();
308- ArrayRef<int64_t > inputShape = inputType.getShape ();
309-
310216 VectorType outputType = insertStridedSliceOp.getType ();
311- ArrayRef<int64_t > outputShape = outputType.getShape ();
312217 int64_t nOutputElements = outputType.getNumElements ();
313218
314- FailureOr<SmallVector<int64_t >> offsets =
315- intsFromArrayAttr ( insertStridedSliceOp.getOffsets () );
316- if (failed (offsets)) {
219+ FailureOr<SmallVector<int64_t >> sliceIndices =
220+ insertStridedSliceOp.getLinearIndices ( );
221+ if (failed (sliceIndices))
317222 return rewriter.notifyMatchFailure (insertStridedSliceOp,
318- " failed to get integer offsets" );
319- }
320- SmallVector<int64_t > sliceIndices = getStridedSliceInsertionIndices (
321- inputShape, outputShape, offsets.value ());
223+ " failed to get indices" );
322224
323225 SmallVector<int64_t > indices (nOutputElements);
324226 std::iota (indices.begin (), indices.end (), 0 );
325- for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices)) {
227+ for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices. value () )) {
326228 indices[sliceIndex] = index + nOutputElements;
327229 }
328230
0 commit comments