Skip to content
58 changes: 31 additions & 27 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,41 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
/// Return true if `vectorType` is a contiguous slice of `memrefType`,
/// in the sense that it can be read/written from/to a contiguous area
/// of the memref.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
/// The leading unit dimensions of the vector type are ignored as they
/// are not relevant to the result. Let N be the number of the vector
/// dimensions after ignoring a leading sequence of unit ones.
///
/// There might be some restriction on the leading dim of `VectorType`:
/// For `vectorType` to be a contiguous slice of `memrefType`
/// a) the N trailing dimensions of `memrefType` must be contiguous, and
/// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
/// must match.
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
/// of `memrefType` then the leading dim of `vectorType` can be
/// arbitrary.
///
/// Ex. 1.1 contiguous slice, perfect match
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
/// corresponding trailing dim in `memrefType` then the remaining
/// leading dims of `vectorType` have to be 1 (the first non-matching
/// dim can be arbitrary).
/// Examples:
///
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
/// Ex.1 contiguous slice, perfect match
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
/// Ex.3 non-contiguous slice, 2 != 3
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
/// 2 != 3 (allowed)
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
/// 2 != 3 (allowed)
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
/// Ex.7 contiguous slice, memref needs to be contiguous only in the last
/// dimension
/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
/// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
/// two dimensions, and it isn't
/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);

/// Returns an iterator for all positions in the leading dimensions of `vType`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,6 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
Expand Down Expand Up @@ -630,7 +629,11 @@ class FlattenContiguousRowMajorTransferReadPattern
if (transferReadOp.getMask())
return failure();

int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
int64_t firstDimToCollapse =
sourceType.getRank() -
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

// 1. Collapse the source memref
Value collapsedSource =
Expand Down Expand Up @@ -722,7 +725,11 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();

int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
int64_t firstDimToCollapse =
sourceType.getRank() -
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

// 1. Collapse the source memref
Value collapsedSource =
Expand Down
25 changes: 8 additions & 17 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (vectorType.isScalable())
return false;

ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank();
// Ignore a leading sequence of adjacent unit dimensions in the vector.
ArrayRef<int64_t> vectorShape =
vectorType.getShape().drop_while([](auto v) { return v == 1; });
auto vecRank = vectorShape.size();

if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;

// Extract the trailing dims and strides of the input memref
// Extract the trailing dims of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank);

// Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
memrefShape.rbegin(), memrefShape.rend());
if (firstNonMatchingDim.first == vectorShape.rend())
return true;

// One non-matching dim is still fine, however the remaining leading dims of
// `vectorType` need to be 1.
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
vectorShape.rend());

return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
// Compare the dims of `vectorType` against `memrefType`.
// All of the dimensions, except the first must match.
return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
}

std::optional<StaticTileOffsetRange>
Expand Down
Loading
Loading