Skip to content

Commit 04df359

Browse files
authored
[mlir][vector] Simplify createReadOrMaskedRead (#163736)
Simplify `createReadOrMaskedRead` to only require _one_ argument to specify the vector type to read (passed as `VectorType`) instead of passing vector-sizes and scalable-flags independently (i.e. _two_ arguments). A simple overload is provided for users that wouldn't re-use the corresponding `VectorType` (and hence there's no point for them to create). While there are no users upstream for this overload, it's been helpful downstream.
1 parent 2690f05 commit 04df359

File tree

3 files changed

+50
-29
lines changed

3 files changed

+50
-29
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,17 @@ bool isLinearizableVector(VectorType type);
219219

220220
/// Creates a TransferReadOp from `source`.
221221
///
222-
/// The shape of the vector to read is specified via `inputVectorSizes`. If the
223-
/// shape of the output vector differs from the shape of the value being read,
224-
/// masking is used to avoid out-of-bounds accesses. Set
222+
/// If the shape of vector to read differs from the shape of the value being
223+
/// read, masking is used to avoid out-of-bounds accesses. Set
225224
/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
226225
/// instead of explicit masks.
227226
///
228227
/// Note: all read offsets are set to 0.
228+
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
229+
const VectorType &vecToReadTy,
230+
std::optional<Value> padValue = std::nullopt,
231+
bool useInBoundsInsteadOfMasking = false);
232+
229233
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
230234
ArrayRef<int64_t> inputVectorSizes,
231235
std::optional<Value> padValue = std::nullopt,

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,9 +1890,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18901890

18911891
// Create masked TransferReadOp.
18921892
auto maskedRead = vector::createReadOrMaskedRead(
1893-
rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
1894-
useInBoundsInsteadOfMasking,
1895-
/*inputScalableVecSizes=*/{});
1893+
rewriter, loc, packOp.getSource(), readVecType, padValue,
1894+
useInBoundsInsteadOfMasking);
18961895

18971896
// Create ShapeCastOp.
18981897
auto shapeCastOp = vector::ShapeCastOp::create(
@@ -1977,9 +1976,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19771976
}
19781977

19791978
// -- Generate the read operation --
1979+
VectorType readVecType =
1980+
VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
1981+
readScalableVectorFlags);
19801982
Value readResult = vector::createReadOrMaskedRead(
1981-
rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
1982-
useInBoundsInsteadOfMasking, readScalableVectorFlags);
1983+
rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
1984+
useInBoundsInsteadOfMasking);
19831985

19841986
// -- Generate the transpose operation --
19851987
PackingMetadata packMetadata;
@@ -2025,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
20252027
.reifyResultShapes(rewriter, reifiedReturnShapes);
20262028
(void)status; // prevent unused variable warning on non-assert builds
20272029
assert(succeeded(status) && "failed to reify result shapes");
2030+
auto readType = VectorType::get(inputVectorSizes, padValue.getType());
20282031
auto maskedRead = vector::createReadOrMaskedRead(
2029-
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2030-
/*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
2032+
rewriter, loc, padOp.getSource(), readType, padValue,
2033+
/*useInBoundsInsteadOfMasking=*/false);
20312034

20322035
// Create Xfer write Op
20332036
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2222,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
22222225
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
22232226

22242227
Value read = mlir::vector::createReadOrMaskedRead(
2225-
rewriter, loc, opOperand.get(), readType.getShape(),
2228+
rewriter, loc, opOperand.get(), readType,
22262229
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2227-
/*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
2230+
/*useInBoundsInsteadOfMasking=*/false);
22282231
vecOperands.push_back(read);
22292232
}
22302233

@@ -3165,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31653168
SmallVector<Value> readIndices(
31663169
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
31673170
Value read = mlir::vector::createReadOrMaskedRead(
3168-
rewriter, loc, source, vecType.getShape(), padValue,
3169-
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
3170-
/*inputScalableVecSizes=*/{});
3171+
rewriter, loc, source, vecType, padValue,
3172+
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
31713173

31723174
// Create write
31733175
auto writeIndices =

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -322,46 +322,61 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
322322
std::optional<Value> padValue,
323323
bool useInBoundsInsteadOfMasking,
324324
ArrayRef<bool> inputScalableVecDims) {
325-
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
325+
VectorType vecToReadTy = VectorType::get(
326+
inputVectorSizes, cast<ShapedType>(source.getType()).getElementType(),
327+
inputScalableVecDims);
328+
329+
return createReadOrMaskedRead(builder, loc, source, vecToReadTy, padValue,
330+
useInBoundsInsteadOfMasking);
331+
}
332+
333+
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
334+
Value source,
335+
const VectorType &vecToReadTy,
336+
std::optional<Value> padValue,
337+
bool useInBoundsInsteadOfMasking) {
338+
assert(!llvm::is_contained(vecToReadTy.getScalableDims(),
339+
ShapedType::kDynamic) &&
326340
"invalid input vector sizes");
327341
auto sourceShapedType = cast<ShapedType>(source.getType());
328342
auto sourceShape = sourceShapedType.getShape();
329-
assert(sourceShape.size() == inputVectorSizes.size() &&
343+
344+
int64_t vecToReadRank = vecToReadTy.getRank();
345+
auto vecToReadShape = vecToReadTy.getShape();
346+
347+
assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
330348
"expected same ranks.");
331-
auto vectorType =
332-
VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
333-
inputScalableVecDims);
334349
assert((!padValue.has_value() ||
335350
padValue.value().getType() == sourceShapedType.getElementType()) &&
336351
"expected same pad element type to match source element type");
337-
int64_t readRank = inputVectorSizes.size();
352+
338353
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
339-
SmallVector<bool> inBoundsVal(readRank, true);
354+
SmallVector<bool> inBoundsVal(vecToReadRank, true);
340355

341356
if (useInBoundsInsteadOfMasking) {
342357
// Update the inBounds attribute.
343358
// FIXME: This computation is too weak - it ignores the read indices.
344-
for (unsigned i = 0; i < readRank; i++)
345-
inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
359+
for (unsigned i = 0; i < vecToReadRank; i++)
360+
inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
346361
ShapedType::isStatic(sourceShape[i]);
347362
}
348363
auto transferReadOp = vector::TransferReadOp::create(
349364
builder, loc,
350-
/*vectorType=*/vectorType,
365+
/*vectorType=*/vecToReadTy,
351366
/*source=*/source,
352-
/*indices=*/SmallVector<Value>(readRank, zero),
367+
/*indices=*/SmallVector<Value>(vecToReadRank, zero),
353368
/*padding=*/padValue,
354369
/*inBounds=*/inBoundsVal);
355370

356-
if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
371+
if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
372+
useInBoundsInsteadOfMasking)
357373
return transferReadOp;
358374
SmallVector<OpFoldResult> mixedSourceDims =
359375
isa<MemRefType>(source.getType())
360376
? memref::getMixedSizes(builder, loc, source)
361377
: tensor::getMixedSizes(builder, loc, source);
362378

363-
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
364-
inputScalableVecDims);
379+
auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
365380
Value mask =
366381
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
367382
return mlir::vector::maskOperation(builder, transferReadOp, mask)

0 commit comments

Comments
 (0)