@@ -109,17 +109,103 @@ struct LinearizeVectorizable final
109109 }
110110};
111111
112- // / This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
113- // / on a linearized vector.
114- // / Following,
112+ template <typename TOp>
113+ static bool stridesAllOne (TOp op) {
114+ static_assert (
115+ std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116+ std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117+ " expected vector.extract_strided_slice or vector.insert_strided_slice" );
118+ ArrayAttr strides = op.getStrides ();
119+ return llvm::all_of (
120+ strides, [](auto stride) { return isConstantIntValue (stride, 1 ); });
121+ }
122+
123+ // / Convert an array of attributes into a vector of integers, if possible.
124+ static FailureOr<SmallVector<int64_t >> intsFromArrayAttr (ArrayAttr attrs) {
125+ if (!attrs)
126+ return failure ();
127+ SmallVector<int64_t > ints;
128+ ints.reserve (attrs.size ());
129+ for (auto attr : attrs) {
130+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
131+ ints.push_back (intAttr.getInt ());
132+ } else {
133+ return failure ();
134+ }
135+ }
136+ return ints;
137+ }
138+
139+ // / Consider inserting a vector of shape `small` into a vector of shape `large`,
140+ // / at position `offsets`: this function enumeratates all the indices in `large`
141+ // / that are written to. The enumeration is with row-major ordering.
142+ // /
143+ // / Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
144+ // / positions written to are (1,3) and (1,4), which have linearized indices 8
145+ // / and 9. So [8,9] is returned.
146+ SmallVector<int64_t > static getFlattenedStridedSliceIndices (
147+ ArrayRef<int64_t > small, ArrayRef<int64_t > large,
148+ ArrayRef<int64_t > offsets) {
149+
150+ // Example of alignment between, `large`, `small` and `offsets`:
151+ // large = 4, 5, 6, 7, 8
152+ // small = 1, 6, 7, 8
153+ // offsets = 2, 3, 0
154+ //
155+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
156+ assert (large.size () >= small.size ());
157+ assert (large.size () >= offsets.size ());
158+ unsigned delta = large.size () - small.size ();
159+ unsigned nOffsets = offsets.size ();
160+ auto getSmall = [&](int64_t i) { return i >= delta ? small[i - delta] : 1 ; };
161+ auto getOffset = [&](int64_t i) { return i < nOffsets ? offsets[i] : 0 ; };
162+
163+ // Using 2 vectors of indices, at each iteration populate the updated set of
164+ // indices based on the old set of indices, and the size of the small vector
165+ // in the current iteration.
166+ SmallVector<int64_t > indices{0 };
167+ SmallVector<int64_t > nextIndices;
168+ int64_t stride = 1 ;
169+ for (int i = large.size () - 1 ; i >= 0 ; --i) {
170+ auto currentSize = indices.size ();
171+ auto smallSize = getSmall (i);
172+ auto nextSize = currentSize * smallSize;
173+ nextIndices.resize (nextSize);
174+ int64_t *base = nextIndices.begin ();
175+ int64_t offset = getOffset (i) * stride;
176+ for (int j = 0 ; j < smallSize; ++j) {
177+ for (uint64_t k = 0 ; k < currentSize; ++k) {
178+ base[k] = indices[k] + offset;
179+ }
180+ offset += stride;
181+ base += currentSize;
182+ }
183+ stride *= large[i];
184+ std::swap (indices, nextIndices);
185+ nextIndices.clear ();
186+ }
187+ return indices;
188+ }
189+
190+ // / This pattern converts a vector.extract_strided_slice operation into a
191+ // / vector.shuffle operation that has a rank-1 (linearized) operand and result.
192+ // /
193+ // / For example, the following:
194+ // /
195+ // / ```
115196// / vector.extract_strided_slice %source
116197// / { offsets = [..], strides = [..], sizes = [..] }
198+ // / ```
199+ // /
117200// / is converted to :
201+ // / ```
118202// / %source_1d = vector.shape_cast %source
119- // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
120- // / %out_nd = vector.shape_cast %out_1d
121- // / `shuffle_indices_1d` is computed using the offsets and sizes of the
122- // / extraction.
203+ // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
204+ // / %out_nd = vector.shape_cast %out_1d
205+ // / ```
206+ // /
207+ // / `shuffle_indices_1d` is computed using the offsets and sizes of the original
208+ // / vector.extract_strided_slice operation.
123209struct LinearizeVectorExtractStridedSlice final
124210 : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
125211 using OpConversionPattern::OpConversionPattern;
@@ -129,88 +215,109 @@ struct LinearizeVectorExtractStridedSlice final
129215 : OpConversionPattern(typeConverter, context, benefit) {}
130216
131217 LogicalResult
132- matchAndRewrite (vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
218+ matchAndRewrite (vector::ExtractStridedSliceOp extractStridedSliceOp,
219+ OpAdaptor adaptor,
133220 ConversionPatternRewriter &rewriter) const override {
134- VectorType dstType =
135- getTypeConverter ()->convertType <VectorType>(extractOp.getType ());
136- assert (dstType && " vector type destination expected." );
137- if (extractOp.getVector ().getType ().isScalable () || dstType.isScalable ())
138- return rewriter.notifyMatchFailure (extractOp,
139- " scalable vectors are not supported." );
140221
141- ArrayAttr offsets = extractOp.getOffsets ();
142- ArrayAttr sizes = extractOp.getSizes ();
143- ArrayAttr strides = extractOp.getStrides ();
144- if (!isConstantIntValue (strides[0 ], 1 ))
145- return rewriter.notifyMatchFailure (
146- extractOp, " Strided slice with stride != 1 is not supported." );
147- Value srcVector = adaptor.getVector ();
148- // If kD offsets are specified for nD source vector (n > k), the granularity
149- // of the extraction is greater than 1. In this case last (n-k) dimensions
150- // form the extraction granularity.
151- // Example :
152- // vector.extract_strided_slice %src {
153- // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
154- // vector<4x8x8xf32> to vector<2x2x8xf32>
155- // Here, extraction granularity is 8.
156- int64_t extractGranularitySize = 1 ;
157- int64_t nD = extractOp.getSourceVectorType ().getRank ();
158- int64_t kD = (int64_t )offsets.size ();
159- int64_t k = kD ;
160- while (k < nD) {
161- extractGranularitySize *= extractOp.getSourceVectorType ().getShape ()[k];
162- ++k;
222+ VectorType flatOutputType = getTypeConverter ()->convertType <VectorType>(
223+ extractStridedSliceOp.getType ());
224+ assert (flatOutputType && " vector type expected" );
225+
226+ if (!stridesAllOne (extractStridedSliceOp)) {
227+ return rewriter.notifyMatchFailure (extractStridedSliceOp,
228+ " strides other than 1 not supported" );
163229 }
164- // Get total number of extracted slices.
165- int64_t nExtractedSlices = 1 ;
166- for (Attribute size : sizes) {
167- nExtractedSlices *= cast<IntegerAttr>(size).getInt ();
230+
231+ ArrayRef<int64_t > inputShape =
232+ extractStridedSliceOp.getSourceVectorType ().getShape ();
233+
234+ ArrayRef<int64_t > outputType = extractStridedSliceOp.getType ().getShape ();
235+
236+ auto maybeIntOffsets =
237+ intsFromArrayAttr (extractStridedSliceOp.getOffsets ());
238+ if (failed (maybeIntOffsets)) {
239+ return rewriter.notifyMatchFailure (extractStridedSliceOp,
240+ " failed to get integer offsets" );
168241 }
169- // Compute the strides of the source vector considering first k dimensions.
170- llvm::SmallVector<int64_t , 4 > sourceStrides (kD , extractGranularitySize);
171- for (int i = kD - 2 ; i >= 0 ; --i) {
172- sourceStrides[i] = sourceStrides[i + 1 ] *
173- extractOp.getSourceVectorType ().getShape ()[i + 1 ];
242+
243+ SmallVector<int64_t > indices = getFlattenedStridedSliceIndices (
244+ outputType, inputShape, maybeIntOffsets.value ());
245+
246+ Value srcVector = adaptor.getVector ();
247+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
248+ extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
249+ return success ();
250+ }
251+ };
252+
253+ // / This pattern converts a vector.insert_strided_slice operation into a
254+ // / vector.shuffle operation that has rank-1 (linearized) operands and result.
255+ // /
256+ // / For example, the following:
257+ // / ```
258+ // / %0 = vector.insert_strided_slice %to_store, %into
259+ // / {offsets = [1, 0, 0, 0], strides = [1, 1]}
260+ // / : vector<2x2xi8> into vector<2x1x3x2xi8>
261+ // / ```
262+ // /
263+ // / is converted to
264+ // / ```
265+ // / %to_store_1d
266+ // / = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
267+ // / %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
268+ // / %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
269+ // / %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
270+ // / ```
271+ // /
272+ // / where shuffle_indices_1d in this case is
273+ // / [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
274+ // / ^^^^^^^^^^^^^^
275+ // / to_store_1d
276+ // /
277+ struct LinearizeVectorInsertStridedSlice final
278+ : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
279+ using OpConversionPattern::OpConversionPattern;
280+ LinearizeVectorInsertStridedSlice (const TypeConverter &typeConverter,
281+ MLIRContext *context,
282+ PatternBenefit benefit = 1 )
283+ : OpConversionPattern(typeConverter, context, benefit) {}
284+
285+ LogicalResult
286+ matchAndRewrite (vector::InsertStridedSliceOp insertStridedSliceOp,
287+ OpAdaptor adaptor,
288+ ConversionPatternRewriter &rewriter) const override {
289+
290+ if (!stridesAllOne (insertStridedSliceOp)) {
291+ return rewriter.notifyMatchFailure (insertStridedSliceOp,
292+ " strides other than 1 not supported" );
174293 }
175- // Final shuffle indices has nExtractedSlices * extractGranularitySize
176- // elements.
177- llvm::SmallVector<int64_t , 4 > indices (nExtractedSlices *
178- extractGranularitySize);
179- // Compute the strides of the extracted kD vector.
180- llvm::SmallVector<int64_t , 4 > extractedStrides (kD , 1 );
181- // Compute extractedStrides.
182- for (int i = kD - 2 ; i >= 0 ; --i) {
183- extractedStrides[i] =
184- extractedStrides[i + 1 ] * cast<IntegerAttr>(sizes[i + 1 ]).getInt ();
294+
295+ VectorType inputType = insertStridedSliceOp.getValueToStore ().getType ();
296+ ArrayRef<int64_t > inputShape = inputType.getShape ();
297+
298+ VectorType outputType = insertStridedSliceOp.getType ();
299+ ArrayRef<int64_t > outputShape = outputType.getShape ();
300+ int64_t nOutputElements = outputType.getNumElements ();
301+
302+ auto maybeIntOffsets = intsFromArrayAttr (insertStridedSliceOp.getOffsets ());
303+ if (failed (maybeIntOffsets)) {
304+ return rewriter.notifyMatchFailure (insertStridedSliceOp,
305+ " failed to get integer offsets" );
185306 }
186- // Iterate over all extracted slices from 0 to nExtractedSlices - 1
187- // and compute the multi-dimensional index and the corresponding linearized
188- // index within the source vector.
189- for (int64_t i = 0 ; i < nExtractedSlices; ++i) {
190- int64_t index = i;
191- // Compute the corresponding multi-dimensional index.
192- llvm::SmallVector<int64_t , 4 > multiDimIndex (kD , 0 );
193- for (int64_t j = 0 ; j < kD ; ++j) {
194- multiDimIndex[j] = (index / extractedStrides[j]);
195- index -= multiDimIndex[j] * extractedStrides[j];
196- }
197- // Compute the corresponding linearized index in the source vector
198- // i.e. shift the multiDimIndex by the offsets.
199- int64_t linearizedIndex = 0 ;
200- for (int64_t j = 0 ; j < kD ; ++j) {
201- linearizedIndex +=
202- (cast<IntegerAttr>(offsets[j]).getInt () + multiDimIndex[j]) *
203- sourceStrides[j];
204- }
205- // Fill the indices array form linearizedIndex to linearizedIndex +
206- // extractGranularitySize.
207- for (int64_t j = 0 ; j < extractGranularitySize; ++j) {
208- indices[i * extractGranularitySize + j] = linearizedIndex + j;
209- }
307+ SmallVector<int64_t > sliceIndices = getFlattenedStridedSliceIndices (
308+ inputShape, outputShape, maybeIntOffsets.value ());
309+
310+ SmallVector<int64_t > indices (nOutputElements, 0 );
311+ std::iota (indices.begin (), indices.end (), 0 );
312+ for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices)) {
313+ indices[sliceIndex] = index + nOutputElements;
210314 }
211- // Perform a shuffle to extract the kD vector.
212- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
213- extractOp, dstType, srcVector, srcVector, indices);
315+
316+ Value flatToStore = adaptor.getValueToStore ();
317+ Value flatDest = adaptor.getDest ();
318+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(insertStridedSliceOp,
319+ flatDest.getType (), flatDest,
320+ flatToStore, indices);
214321 return success ();
215322 }
216323};
@@ -296,7 +403,7 @@ struct LinearizeVectorExtract final
296403 // Skip if result is not a vector type
297404 if (!isa<VectorType>(extractOp.getType ()))
298405 return rewriter.notifyMatchFailure (extractOp,
299- " scalar extract is not supported. " );
406+ " scalar extract not supported" );
300407 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
301408 assert (dstTy && " expected 1-D vector type" );
302409
@@ -539,6 +646,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
539646 const TypeConverter &typeConverter, const ConversionTarget &target,
540647 RewritePatternSet &patterns) {
541648 patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
542- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
543- typeConverter, patterns.getContext ());
649+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
650+ LinearizeVectorInsertStridedSlice>(typeConverter,
651+ patterns.getContext ());
544652}
0 commit comments