@@ -109,17 +109,110 @@ struct LinearizeVectorizable final
109109 }
110110};
111111
112- // / This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
113- // / on a linearized vector.
114- // / Following,
112+ template <typename TOp>
113+ static bool stridesAllOne (TOp op) {
114+ static_assert (
115+ std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116+ std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117+ " expected vector.extract_strided_slice or vector.insert_strided_slice" );
118+ ArrayAttr strides = op.getStrides ();
119+ return llvm::all_of (
120+ strides, [](auto stride) { return isConstantIntValue (stride, 1 ); });
121+ }
122+
123+ // / Convert an array of attributes into a vector of integers, if possible.
124+ static FailureOr<SmallVector<int64_t >> intsFromArrayAttr (ArrayAttr attrs) {
125+ if (!attrs)
126+ return failure ();
127+ SmallVector<int64_t > ints;
128+ ints.reserve (attrs.size ());
129+ for (auto attr : attrs) {
130+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
131+ ints.push_back (intAttr.getInt ());
132+ } else {
133+ return failure ();
134+ }
135+ }
136+ return ints;
137+ }
138+
139+ // / Consider inserting a vector of shape `small` into a vector of shape `large`,
140+ // / at position `offsets`: this function enumeratates all the indices in `large`
141+ // / that are written to. The enumeration is with row-major ordering.
142+ // /
143+ // / Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
144+ // / positions written to are (1,3) and (1,4), which have linearized indices 8
145+ // / and 9. So [8,9] is returned.
146+ // /
147+ // / The length of the returned vector is equal to the number of elements in
148+ // / the shape `small` (i.e. the product of dimensions of `small`).
149+ SmallVector<int64_t > static getStridedSliceInsertionIndices (
150+ ArrayRef<int64_t > small, ArrayRef<int64_t > large,
151+ ArrayRef<int64_t > offsets) {
152+
153+ // Example of alignment between, `large`, `small` and `offsets`:
154+ // large = 4, 5, 6, 7, 8
155+ // small = 1, 6, 7, 8
156+ // offsets = 2, 3, 0
157+ //
158+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
159+ assert ((large.size () >= small.size ()) &&
160+ " rank of 'large' cannot be lower than rank of 'small'" );
161+ assert ((large.size () >= offsets.size ()) &&
162+ " rank of 'large' cannot be lower than the number of offsets" );
163+ unsigned delta = large.size () - small.size ();
164+ unsigned nOffsets = offsets.size ();
165+ auto getSmall = [&](int64_t i) -> int64_t {
166+ return i >= delta ? small[i - delta] : 1 ;
167+ };
168+ auto getOffset = [&](int64_t i) -> int64_t {
169+ return i < nOffsets ? offsets[i] : 0 ;
170+ };
171+
172+ // Using 2 vectors of indices, at each iteration populate the updated set of
173+ // indices based on the old set of indices, and the size of the small vector
174+ // in the current iteration.
175+ SmallVector<int64_t > indices{0 };
176+ int64_t stride = 1 ;
177+ for (int i = large.size () - 1 ; i >= 0 ; --i) {
178+ int64_t currentSize = indices.size ();
179+ int64_t smallSize = getSmall (i);
180+ int64_t nextSize = currentSize * smallSize;
181+ SmallVector<int64_t > nextIndices (nextSize);
182+ int64_t *base = nextIndices.begin ();
183+ int64_t offset = getOffset (i) * stride;
184+ for (int j = 0 ; j < smallSize; ++j) {
185+ for (int k = 0 ; k < currentSize; ++k) {
186+ base[k] = indices[k] + offset;
187+ }
188+ offset += stride;
189+ base += currentSize;
190+ }
191+ stride *= large[i];
192+ indices = std::move (nextIndices);
193+ }
194+ return indices;
195+ }
196+
197+ // / This pattern converts a vector.extract_strided_slice operation into a
198+ // / vector.shuffle operation that has a rank-1 (linearized) operand and result.
199+ // /
200+ // / For example, the following:
201+ // /
202+ // / ```
115203// / vector.extract_strided_slice %source
116204// / { offsets = [..], strides = [..], sizes = [..] }
205+ // / ```
206+ // /
117207// / is converted to :
208+ // / ```
118209// / %source_1d = vector.shape_cast %source
119- // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
120- // / %out_nd = vector.shape_cast %out_1d
121- // / `shuffle_indices_1d` is computed using the offsets and sizes of the
122- // / extraction.
210+ // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
211+ // / %out_nd = vector.shape_cast %out_1d
212+ // / ```
213+ // /
214+ // / `shuffle_indices_1d` is computed using the offsets and sizes of the original
215+ // / vector.extract_strided_slice operation.
123216struct LinearizeVectorExtractStridedSlice final
124217 : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
125218 using OpConversionPattern::OpConversionPattern;
@@ -129,88 +222,116 @@ struct LinearizeVectorExtractStridedSlice final
129222 : OpConversionPattern(typeConverter, context, benefit) {}
130223
131224 LogicalResult
132- matchAndRewrite (vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
225+ matchAndRewrite (vector::ExtractStridedSliceOp extractStridedSliceOp,
226+ OpAdaptor adaptor,
133227 ConversionPatternRewriter &rewriter) const override {
134- VectorType dstType =
135- getTypeConverter ()->convertType <VectorType>(extractOp.getType ());
136- assert (dstType && " vector type destination expected." );
137- if (extractOp.getVector ().getType ().isScalable () || dstType.isScalable ())
138- return rewriter.notifyMatchFailure (extractOp,
139- " scalable vectors are not supported." );
140228
141- ArrayAttr offsets = extractOp.getOffsets ();
142- ArrayAttr sizes = extractOp.getSizes ();
143- ArrayAttr strides = extractOp.getStrides ();
144- if (!isConstantIntValue (strides[0 ], 1 ))
229+ VectorType flatOutputType = getTypeConverter ()->convertType <VectorType>(
230+ extractStridedSliceOp.getType ());
231+ assert (flatOutputType && " vector type expected" );
232+
233+ // Expect a legalization failure if the strides are not all 1 (if ever the
234+ // verifier for extract_strided_slice allows non-1 strides).
235+ if (!stridesAllOne (extractStridedSliceOp)) {
145236 return rewriter.notifyMatchFailure (
146- extractOp, " Strided slice with stride != 1 is not supported." );
147- Value srcVector = adaptor.getVector ();
148- // If kD offsets are specified for nD source vector (n > k), the granularity
149- // of the extraction is greater than 1. In this case last (n-k) dimensions
150- // form the extraction granularity.
151- // Example :
152- // vector.extract_strided_slice %src {
153- // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
154- // vector<4x8x8xf32> to vector<2x2x8xf32>
155- // Here, extraction granularity is 8.
156- int64_t extractGranularitySize = 1 ;
157- int64_t nD = extractOp.getSourceVectorType ().getRank ();
158- int64_t kD = (int64_t )offsets.size ();
159- int64_t k = kD ;
160- while (k < nD) {
161- extractGranularitySize *= extractOp.getSourceVectorType ().getShape ()[k];
162- ++k;
237+ extractStridedSliceOp,
238+ " extract_strided_slice with strides != 1 not supported" );
163239 }
164- // Get total number of extracted slices.
165- int64_t nExtractedSlices = 1 ;
166- for (Attribute size : sizes) {
167- nExtractedSlices *= cast<IntegerAttr>(size).getInt ();
240+
241+ FailureOr<SmallVector<int64_t >> offsets =
242+ intsFromArrayAttr (extractStridedSliceOp.getOffsets ());
243+ if (failed (offsets)) {
244+ return rewriter.notifyMatchFailure (extractStridedSliceOp,
245+ " failed to get integer offsets" );
168246 }
169- // Compute the strides of the source vector considering first k dimensions.
170- llvm::SmallVector<int64_t , 4 > sourceStrides (kD , extractGranularitySize);
171- for (int i = kD - 2 ; i >= 0 ; --i) {
172- sourceStrides[i] = sourceStrides[i + 1 ] *
173- extractOp.getSourceVectorType ().getShape ()[i + 1 ];
247+
248+ ArrayRef<int64_t > inputShape =
249+ extractStridedSliceOp.getSourceVectorType ().getShape ();
250+
251+ ArrayRef<int64_t > outputShape = extractStridedSliceOp.getType ().getShape ();
252+
253+ SmallVector<int64_t > indices = getStridedSliceInsertionIndices (
254+ outputShape, inputShape, offsets.value ());
255+
256+ Value srcVector = adaptor.getVector ();
257+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
258+ extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
259+ return success ();
260+ }
261+ };
262+
263+ // / This pattern converts a vector.insert_strided_slice operation into a
264+ // / vector.shuffle operation that has rank-1 (linearized) operands and result.
265+ // /
266+ // / For example, the following:
267+ // / ```
268+ // / %0 = vector.insert_strided_slice %to_store, %into
269+ // / {offsets = [1, 0, 0, 0], strides = [1, 1]}
270+ // / : vector<2x2xi8> into vector<2x1x3x2xi8>
271+ // / ```
272+ // /
273+ // / is converted to
274+ // / ```
275+ // / %to_store_1d
276+ // / = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
277+ // / %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
278+ // / %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
279+ // / %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
280+ // / ```
281+ // /
282+ // / where shuffle_indices_1d in this case is
283+ // / [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
284+ // / ^^^^^^^^^^^^^^
285+ // / to_store_1d
286+ // /
287+ struct LinearizeVectorInsertStridedSlice final
288+ : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
289+ using OpConversionPattern::OpConversionPattern;
290+ LinearizeVectorInsertStridedSlice (const TypeConverter &typeConverter,
291+ MLIRContext *context,
292+ PatternBenefit benefit = 1 )
293+ : OpConversionPattern(typeConverter, context, benefit) {}
294+
295+ LogicalResult
296+ matchAndRewrite (vector::InsertStridedSliceOp insertStridedSliceOp,
297+ OpAdaptor adaptor,
298+ ConversionPatternRewriter &rewriter) const override {
299+
300+ // Expect a legalization failure if the strides are not all 1 (if ever the
301+ // verifier for insert_strided_slice allows non-1 strides).
302+ if (!stridesAllOne (insertStridedSliceOp)) {
303+ return rewriter.notifyMatchFailure (
304+ insertStridedSliceOp,
305+ " insert_strided_slice with strides != 1 not supported" );
174306 }
175- // Final shuffle indices has nExtractedSlices * extractGranularitySize
176- // elements.
177- llvm::SmallVector<int64_t , 4 > indices (nExtractedSlices *
178- extractGranularitySize);
179- // Compute the strides of the extracted kD vector.
180- llvm::SmallVector<int64_t , 4 > extractedStrides (kD , 1 );
181- // Compute extractedStrides.
182- for (int i = kD - 2 ; i >= 0 ; --i) {
183- extractedStrides[i] =
184- extractedStrides[i + 1 ] * cast<IntegerAttr>(sizes[i + 1 ]).getInt ();
307+
308+ VectorType inputType = insertStridedSliceOp.getValueToStore ().getType ();
309+ ArrayRef<int64_t > inputShape = inputType.getShape ();
310+
311+ VectorType outputType = insertStridedSliceOp.getType ();
312+ ArrayRef<int64_t > outputShape = outputType.getShape ();
313+ int64_t nOutputElements = outputType.getNumElements ();
314+
315+ FailureOr<SmallVector<int64_t >> offsets =
316+ intsFromArrayAttr (insertStridedSliceOp.getOffsets ());
317+ if (failed (offsets)) {
318+ return rewriter.notifyMatchFailure (insertStridedSliceOp,
319+ " failed to get integer offsets" );
185320 }
186- // Iterate over all extracted slices from 0 to nExtractedSlices - 1
187- // and compute the multi-dimensional index and the corresponding linearized
188- // index within the source vector.
189- for (int64_t i = 0 ; i < nExtractedSlices; ++i) {
190- int64_t index = i;
191- // Compute the corresponding multi-dimensional index.
192- llvm::SmallVector<int64_t , 4 > multiDimIndex (kD , 0 );
193- for (int64_t j = 0 ; j < kD ; ++j) {
194- multiDimIndex[j] = (index / extractedStrides[j]);
195- index -= multiDimIndex[j] * extractedStrides[j];
196- }
197- // Compute the corresponding linearized index in the source vector
198- // i.e. shift the multiDimIndex by the offsets.
199- int64_t linearizedIndex = 0 ;
200- for (int64_t j = 0 ; j < kD ; ++j) {
201- linearizedIndex +=
202- (cast<IntegerAttr>(offsets[j]).getInt () + multiDimIndex[j]) *
203- sourceStrides[j];
204- }
205- // Fill the indices array form linearizedIndex to linearizedIndex +
206- // extractGranularitySize.
207- for (int64_t j = 0 ; j < extractGranularitySize; ++j) {
208- indices[i * extractGranularitySize + j] = linearizedIndex + j;
209- }
321+ SmallVector<int64_t > sliceIndices = getStridedSliceInsertionIndices (
322+ inputShape, outputShape, offsets.value ());
323+
324+ SmallVector<int64_t > indices (nOutputElements);
325+ std::iota (indices.begin (), indices.end (), 0 );
326+ for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices)) {
327+ indices[sliceIndex] = index + nOutputElements;
210328 }
211- // Perform a shuffle to extract the kD vector.
212- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
213- extractOp, dstType, srcVector, srcVector, indices);
329+
330+ Value flatToStore = adaptor.getValueToStore ();
331+ Value flatDest = adaptor.getDest ();
332+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(insertStridedSliceOp,
333+ flatDest.getType (), flatDest,
334+ flatToStore, indices);
214335 return success ();
215336 }
216337};
@@ -296,7 +417,7 @@ struct LinearizeVectorExtract final
296417 // Skip if result is not a vector type
297418 if (!isa<VectorType>(extractOp.getType ()))
298419 return rewriter.notifyMatchFailure (extractOp,
299- " scalar extract is not supported. " );
420+ " scalar extract not supported" );
300421 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
301422 assert (dstTy && " expected 1-D vector type" );
302423
@@ -453,8 +574,8 @@ struct LinearizeVectorSplat final
453574static bool isNotLinearizableBecauseScalable (Operation *op) {
454575
455576 bool unsupported =
456- isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
457- op);
577+ isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
578+ vector::ExtractOp, vector::InsertOp>( op);
458579 if (!unsupported)
459580 return false ;
460581
@@ -539,6 +660,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
539660 const TypeConverter &typeConverter, const ConversionTarget &target,
540661 RewritePatternSet &patterns) {
541662 patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
542- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
543- typeConverter, patterns.getContext ());
663+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
664+ LinearizeVectorInsertStridedSlice>(typeConverter,
665+ patterns.getContext ());
544666}
0 commit comments