diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index cf6efaa04ae44..ca71f701f97fa 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -198,85 +198,156 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, return *newMask; } -/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for -/// emitting `vector.extract_strided_slice`. +/// Extracts 1-D subvector from a 1-D vector. +/// +/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements +/// from `src`, starting at `offset`. The result is also a rank-1 vector: +/// +/// vector +/// +/// (`!elType` is the element type of the source vector). As `offset` is a known +/// _static_ value, this helper hook emits `vector.extract_strided_slice`. +/// +/// EXAMPLE: +/// %res = vector.extract_strided_slice %src +/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] } static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, - Value source, int64_t frontOffset, - int64_t subvecSize) { - auto vectorType = cast(source.getType()); - assert(vectorType.getRank() == 1 && "expected 1-D source types"); - assert(frontOffset + subvecSize <= vectorType.getNumElements() && + Value src, int64_t offset, + int64_t numElemsToExtract) { + auto vectorType = cast(src.getType()); + assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector "); + assert(offset + numElemsToExtract <= vectorType.getNumElements() && "subvector out of bounds"); - // do not need extraction if the subvector size is the same as the source - if (vectorType.getNumElements() == subvecSize) - return source; + // When extracting all available elements, just use the source vector as the + // result. + if (vectorType.getNumElements() == numElemsToExtract) + return src; - auto offsets = rewriter.getI64ArrayAttr({frontOffset}); - auto sizes = rewriter.getI64ArrayAttr({subvecSize}); + auto offsets = rewriter.getI64ArrayAttr({offset}); + auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract}); auto strides = rewriter.getI64ArrayAttr({1}); auto resultVectorType = - VectorType::get({subvecSize}, vectorType.getElementType()); + VectorType::get({numElemsToExtract}, vectorType.getElementType()); return rewriter - .create(loc, resultVectorType, source, + .create(loc, resultVectorType, src, offsets, sizes, strides) ->getResult(0); } -/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting -/// at `offset`. it is a wrapper function for emitting +/// Inserts 1-D subvector into a 1-D vector. +/// +/// Inserts the input rank-1 source vector into the destination vector starting +/// at `offset`. As `offset` is a known _static_ value, this helper hook emits /// `vector.insert_strided_slice`. +/// +/// EXAMPLE: +/// %res = vector.insert_strided_slice %src, %dest +/// {offsets = [%offset], strides [1]} static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, Value src, Value dest, int64_t offset) { - [[maybe_unused]] auto srcType = cast(src.getType()); - [[maybe_unused]] auto destType = cast(dest.getType()); - assert(srcType.getRank() == 1 && destType.getRank() == 1 && - "expected source and dest to be vector type"); + auto srcVecTy = cast(src.getType()); + auto destVecTy = cast(dest.getType()); + assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 && + "expected source and dest to be rank-1 vector types"); + + // If overwritting the destination vector, just return the source. + if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0) + return src; + auto offsets = rewriter.getI64ArrayAttr({offset}); auto strides = rewriter.getI64ArrayAttr({1}); - return rewriter.create(loc, dest.getType(), src, + return rewriter.create(loc, destVecTy, src, dest, offsets, strides); } -/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset` -/// and size `numElementsToExtract`, and inserts into the `dest` vector. This -/// function emits multiple `vector.extract` and `vector.insert` ops, so only -/// use it when `offset` cannot be folded into a constant value. +/// Extracts 1-D subvector from a 1-D vector. +/// +/// Given the input rank-1 source vector, extracts `numElemsToExtact` elements +/// from `src`, starting at `offset`. The result is also a rank-1 vector: +/// +/// vector +/// +/// (`!elType` is the element type of the source vector). As `offset` is assumed +/// to be a _dynamic_ SSA value, this helper method generates a sequence of +/// `vector.extract` + `vector.insert` pairs. +/// +/// EXAMPLE: +/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2> +/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2> +/// %c1 = arith.constant 1 : index +/// %idx2 = arith.addi %offset, %c1 : index +/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2> +/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2> +/// (...) static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, - Value source, Value dest, + Value src, Value dest, OpFoldResult offset, - int64_t numElementsToExtract) { - assert(isa(source) && "expected `source` to be a vector type"); - for (int i = 0; i < numElementsToExtract; ++i) { + int64_t numElemsToExtract) { + auto srcVecTy = cast(src.getType()); + assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector "); + // NOTE: We are unable to take the offset into account in the following + // assert, hence its still possible that the subvector is out-of-bounds even + // if the condition is true. + assert(numElemsToExtract <= srcVecTy.getNumElements() && + "subvector out of bounds"); + + // When extracting all available elements, just use the source vector as the + // result. + if (srcVecTy.getNumElements() == numElemsToExtract) + return src; + + for (int i = 0; i < numElemsToExtract; ++i) { Value extractLoc = (i == 0) ? offset.dyn_cast() : rewriter.create( loc, rewriter.getIndexType(), offset.dyn_cast(), rewriter.create(loc, i)); - auto extractOp = - rewriter.create(loc, source, extractLoc); + auto extractOp = rewriter.create(loc, src, extractLoc); dest = rewriter.create(loc, extractOp, dest, i); } return dest; } -/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`. +/// Inserts 1-D subvector into a 1-D vector. +/// +/// Inserts the input rank-1 source vector into the destination vector starting +/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook +/// uses a sequence of `vector.extract` + `vector.insert` pairs. +/// +/// EXAMPLE: +/// %v1 = vector.extract %src[0] : i2 from vector<8xi2> +/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2> +/// %c1 = arith.constant 1 : index +/// %idx2 = arith.addi %offset, %c1 : index +/// %v2 = vector.extract %src[1] : i2 from vector<8xi2> +/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2> +/// (...) static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, - Value source, Value dest, - OpFoldResult destOffsetVar, - size_t length) { - assert(isa(source) && "expected `source` to be a vector type"); - assert(length > 0 && "length must be greater than 0"); - Value destOffsetVal = - getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar); - for (size_t i = 0; i < length; ++i) { + Value src, Value dest, + OpFoldResult offset, + int64_t numElemsToInsert) { + auto srcVecTy = cast(src.getType()); + auto destVecTy = cast(dest.getType()); + assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 && + "expected source and dest to be rank-1 vector types"); + assert(numElemsToInsert > 0 && + "the number of elements to insert must be greater than 0"); + // NOTE: We are unable to take the offset into account in the following + // assert, hence its still possible that the subvector is out-of-bounds even + // if the condition is true. + assert(numElemsToInsert <= destVecTy.getNumElements() && + "subvector out of bounds"); + + Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset); + for (int64_t i = 0; i < numElemsToInsert; ++i) { auto insertLoc = i == 0 ? destOffsetVal : rewriter.create( loc, rewriter.getIndexType(), destOffsetVal, rewriter.create(loc, i)); - auto extractOp = rewriter.create(loc, source, i); + auto extractOp = rewriter.create(loc, src, i); dest = rewriter.create(loc, extractOp, dest, insertLoc); } return dest;