Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 37 additions & 76 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
/// %res = vector.transfer_write %vectorToStore into %dest
/// %res = vector.transfer_write %vecToStore into %dest
///
/// If the leading N dimensions of the vector to store do not match
/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
/// masking is applied to ensure correctness:
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
/// vector.transfer_write %vectorToStore into %dest
/// vector.transfer_write %vecToStore into %dest
/// }
///
/// The mask shape is identical to `vectorToStore` (with the element type ==
/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
/// %write = vector.transfer_write %vectorToStore into %dest
/// %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
/// `writeIndices` specifies the offsets to use. If empty, all indices are set
/// to 0.
///
/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
/// `valueToStore`.
/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
/// already provided in `vectorToStore`.
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
/// are set to 0.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
Value dest,
ArrayRef<int64_t> inputVecSizesForLeadingDims,
SmallVector<Value> writeIndices = {},
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {

ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();

VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();

// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
// In this case, assume that all the required vector sizes have been
// provided.
assert(inputVecSizesForLeadingDims.size() ==
static_cast<size_t>(vecToStoreType.getRank()) &&
"Insufficient number of input vector sizes!");
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the write indices.
for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] =
(destShape[i] >= inputVecSizesForLeadingDims[i]) &&
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
}

Expand All @@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
// Generate the xfer_write Op
Operation *write =
builder.create<vector::TransferWriteOp>(loc,
/*vector=*/vectorToStore,
/*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
Expand All @@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
if (useInBoundsInsteadOfMasking)
return write;

assert(llvm::none_of(
destShape.drop_front(inputVecSizesForLeadingDims.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");

// Check if masking is needed.
bool needMaskForWrite =
!llvm::equal(inputVecSizesForLeadingDims,
destShape.take_front(destRank - vecToStoreRank +
inputVecSizesForLeadingDims.size()));

// If masking is needed, generate the mask and mask the operation.
if (needMaskForWrite) {
// Get the mask shape + type. Missing mask dimensions are taken from
// `vectorToStore`.
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
inputVecSizesForLeadingDims.end());
if (vecToStoreRank >
static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
writeMaskShape.append(vecToStoreShape.begin() +
inputVecSizesForLeadingDims.size(),
vecToStoreShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());

SmallVector<OpFoldResult> destSizes =
tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
destSizes.end());

if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
writeMaskShape))
return write;

Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
loc, writeMaskType, maskSizes);
write = mlir::vector::maskOperation(builder, write, maskForWrite);
}
// Check if masking is needed. If not, exit.
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
return write;

// Compute the mask and mask the write Op.
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());

SmallVector<OpFoldResult> destSizes =
tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());

if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
vecToStoreShape))
return write;

return write;
Value maskForWrite =
builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
return mlir::vector::maskOperation(builder, write, maskForWrite);
}

/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
Expand Down Expand Up @@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), dest,
/*inputVecSizesForLeadingDims=*/inputVectorSizes);
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
newResults.push_back(write->getResult(0));
return success();
}
Expand Down Expand Up @@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
shapeCastOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), dest,
/*inputVecSizesForLeadingDims=*/writeVectorSizes,
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
Expand Down Expand Up @@ -1999,9 +1962,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, maskedRead, dest,
/*inputVecSizesForLeadingDims=*/inputVectorSizes);
Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
newResults.push_back(write->getResult(0));
return success();
}
Expand Down Expand Up @@ -3043,9 +3004,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
writeIndices, inputVectorSizes.empty());

// 4. Finalize
newResults.push_back(write->getResult(0));
Expand Down
Loading