@@ -198,85 +198,156 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
198198 return *newMask;
199199}
200200
201- // / Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
202- // / emitting `vector.extract_strided_slice`.
201+ // / Extracts 1-D subvector from a 1-D vector.
202+ // /
203+ // / Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204+ // / from `src`, starting at `offset`. The result is also a rank-1 vector:
205+ // /
206+ // / vector<numElemsToExtract x !elType>
207+ // /
208+ // / (`!elType` is the element type of the source vector). As `offset` is a known
209+ // / _static_ value, this helper hook emits `vector.extract_strided_slice`.
210+ // /
211+ // / EXAMPLE:
212+ // / %res = vector.extract_strided_slice %src
213+ // / { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
203214static Value staticallyExtractSubvector (OpBuilder &rewriter, Location loc,
204- Value source , int64_t frontOffset ,
205- int64_t subvecSize ) {
206- auto vectorType = cast<VectorType>(source .getType ());
207- assert (vectorType.getRank () == 1 && " expected 1-D source types " );
208- assert (frontOffset + subvecSize <= vectorType.getNumElements () &&
215+ Value src , int64_t offset ,
216+ int64_t numElemsToExtract ) {
217+ auto vectorType = cast<VectorType>(src .getType ());
218+ assert (vectorType.getRank () == 1 && " expected source to be rank- 1-D vector " );
219+ assert (offset + numElemsToExtract <= vectorType.getNumElements () &&
209220 " subvector out of bounds" );
210221
211- // do not need extraction if the subvector size is the same as the source
212- if (vectorType.getNumElements () == subvecSize)
213- return source;
222+ // When extracting all available elements, just use the source vector as the
223+ // result.
224+ if (vectorType.getNumElements () == numElemsToExtract)
225+ return src;
214226
215- auto offsets = rewriter.getI64ArrayAttr ({frontOffset });
216- auto sizes = rewriter.getI64ArrayAttr ({subvecSize });
227+ auto offsets = rewriter.getI64ArrayAttr ({offset });
228+ auto sizes = rewriter.getI64ArrayAttr ({numElemsToExtract });
217229 auto strides = rewriter.getI64ArrayAttr ({1 });
218230
219231 auto resultVectorType =
220- VectorType::get ({subvecSize }, vectorType.getElementType ());
232+ VectorType::get ({numElemsToExtract }, vectorType.getElementType ());
221233 return rewriter
222- .create <vector::ExtractStridedSliceOp>(loc, resultVectorType, source ,
234+ .create <vector::ExtractStridedSliceOp>(loc, resultVectorType, src ,
223235 offsets, sizes, strides)
224236 ->getResult (0 );
225237}
226238
227- // / Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
228- // / at `offset`. it is a wrapper function for emitting
239+ // / Inserts 1-D subvector into a 1-D vector.
240+ // /
241+ // / Inserts the input rank-1 source vector into the destination vector starting
242+ // / at `offset`. As `offset` is a known _static_ value, this helper hook emits
229243// / `vector.insert_strided_slice`.
244+ // /
245+ // / EXAMPLE:
246+ // / %res = vector.insert_strided_slice %src, %dest
247+ // / {offsets = [offset], strides = [1]}
230248static Value staticallyInsertSubvector (OpBuilder &rewriter, Location loc,
231249 Value src, Value dest, int64_t offset) {
232- [[maybe_unused]] auto srcType = cast<VectorType>(src.getType ());
233- [[maybe_unused]] auto destType = cast<VectorType>(dest.getType ());
234- assert (srcType.getRank () == 1 && destType.getRank () == 1 &&
235- " expected source and dest to be vector type" );
250+ auto srcVecTy = cast<VectorType>(src.getType ());
251+ auto destVecTy = cast<VectorType>(dest.getType ());
252+ assert (srcVecTy.getRank () == 1 && destVecTy.getRank () == 1 &&
253+ " expected source and dest to be rank-1 vector types" );
254+
255+ // If overwriting the destination vector, just return the source.
256+ if (srcVecTy.getNumElements () == destVecTy.getNumElements () && offset == 0 )
257+ return src;
258+
236259 auto offsets = rewriter.getI64ArrayAttr ({offset});
237260 auto strides = rewriter.getI64ArrayAttr ({1 });
238- return rewriter.create <vector::InsertStridedSliceOp>(loc, dest. getType () , src,
261+ return rewriter.create <vector::InsertStridedSliceOp>(loc, destVecTy , src,
239262 dest, offsets, strides);
240263}
241264
242- // / Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
243- // / and size `numElementsToExtract`, and inserts into the `dest` vector. This
244- // / function emits multiple `vector.extract` and `vector.insert` ops, so only
245- // / use it when `offset` cannot be folded into a constant value.
265+ // / Extracts 1-D subvector from a 1-D vector.
266+ // /
267+ // / Given the input rank-1 source vector, extracts `numElemsToExtract` elements
268+ // / from `src`, starting at `offset`. The result is also a rank-1 vector:
269+ // /
270+ // / vector<numElemsToExtract x !elType>
271+ // /
272+ // / (`!elType` is the element type of the source vector). As `offset` is assumed
273+ // / to be a _dynamic_ SSA value, this helper method generates a sequence of
274+ // / `vector.extract` + `vector.insert` pairs.
275+ // /
276+ // / EXAMPLE:
277+ // / %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278+ // / %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279+ // / %c1 = arith.constant 1 : index
280+ // / %idx2 = arith.addi %offset, %c1 : index
281+ // / %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282+ // / %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283+ // / (...)
246284static Value dynamicallyExtractSubVector (OpBuilder &rewriter, Location loc,
247- Value source , Value dest,
285+ Value src , Value dest,
248286 OpFoldResult offset,
249- int64_t numElementsToExtract) {
250- assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
251- for (int i = 0 ; i < numElementsToExtract; ++i) {
287+ int64_t numElemsToExtract) {
288+ auto srcVecTy = cast<VectorType>(src.getType ());
289+ assert (srcVecTy.getRank () == 1 && " expected source to be rank-1-D vector " );
290+ // NOTE: We are unable to take the offset into account in the following
291+ // assert, hence it's still possible that the subvector is out-of-bounds even
292+ // if the condition is true.
293+ assert (numElemsToExtract <= srcVecTy.getNumElements () &&
294+ " subvector out of bounds" );
295+
296+ // When extracting all available elements, just use the source vector as the
297+ // result.
298+ if (srcVecTy.getNumElements () == numElemsToExtract)
299+ return src;
300+
301+ for (int i = 0 ; i < numElemsToExtract; ++i) {
252302 Value extractLoc =
253303 (i == 0 ) ? offset.dyn_cast <Value>()
254304 : rewriter.create <arith::AddIOp>(
255305 loc, rewriter.getIndexType (), offset.dyn_cast <Value>(),
256306 rewriter.create <arith::ConstantIndexOp>(loc, i));
257- auto extractOp =
258- rewriter.create <vector::ExtractOp>(loc, source, extractLoc);
307+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, src, extractLoc);
259308 dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, i);
260309 }
261310 return dest;
262311}
263312
264- // / Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
313+ // / Inserts 1-D subvector into a 1-D vector.
314+ // /
315+ // / Inserts the input rank-1 source vector into the destination vector starting
316+ // / at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317+ // / uses a sequence of `vector.extract` + `vector.insert` pairs.
318+ // /
319+ // / EXAMPLE:
320+ // / %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321+ // / %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322+ // / %c1 = arith.constant 1 : index
323+ // / %idx2 = arith.addi %offset, %c1 : index
324+ // / %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325+ // / %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326+ // / (...)
265327static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
266- Value source, Value dest,
267- OpFoldResult destOffsetVar,
268- size_t length) {
269- assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
270- assert (length > 0 && " length must be greater than 0" );
271- Value destOffsetVal =
272- getValueOrCreateConstantIndexOp (rewriter, loc, destOffsetVar);
273- for (size_t i = 0 ; i < length; ++i) {
328+ Value src, Value dest,
329+ OpFoldResult offset,
330+ int64_t numElemsToInsert) {
331+ auto srcVecTy = cast<VectorType>(src.getType ());
332+ auto destVecTy = cast<VectorType>(dest.getType ());
333+ assert (srcVecTy.getRank () == 1 && destVecTy.getRank () == 1 &&
334+ " expected source and dest to be rank-1 vector types" );
335+ assert (numElemsToInsert > 0 &&
336+ " the number of elements to insert must be greater than 0" );
337+ // NOTE: We are unable to take the offset into account in the following
338+ // assert, hence it's still possible that the subvector is out-of-bounds even
339+ // if the condition is true.
340+ assert (numElemsToInsert <= destVecTy.getNumElements () &&
341+ " subvector out of bounds" );
342+
343+ Value destOffsetVal = getValueOrCreateConstantIndexOp (rewriter, loc, offset);
344+ for (int64_t i = 0 ; i < numElemsToInsert; ++i) {
274345 auto insertLoc = i == 0
275346 ? destOffsetVal
276347 : rewriter.create <arith::AddIOp>(
277348 loc, rewriter.getIndexType (), destOffsetVal,
278349 rewriter.create <arith::ConstantIndexOp>(loc, i));
279- auto extractOp = rewriter.create <vector::ExtractOp>(loc, source , i);
350+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, src , i);
280351 dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, insertLoc);
281352 }
282353 return dest;
0 commit comments