
Commit 2845025

[mlir][vector] Simplify createReadOrMaskedRead
Simplify `createReadOrMaskedRead` so that the vector type to read is specified with a single `VectorType` argument, rather than passing the vector sizes and the scalable flags as two separate arguments.
Parent: c08644c, commit: 2845025
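
In practical terms, the change at call sites looks roughly like the sketch below (`readVectorSizes`, `readScalableVectorFlags`, `elementType`, `source` and `padValue` are placeholders for whatever the caller already has in scope):

  // Before: shape and scalable flags were passed separately.
  Value read = vector::createReadOrMaskedRead(
      rewriter, loc, source, readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);

  // After: a single VectorType carries the shape, element type and scalable dims.
  VectorType readVecType =
      VectorType::get(readVectorSizes, elementType, readScalableVectorFlags);
  Value read = vector::createReadOrMaskedRead(
      rewriter, loc, source, readVecType, padValue,
      /*useInBoundsInsteadOfMasking=*/false);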

3 files changed (+36, -36 lines)

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 4 additions & 6 deletions
@@ -219,18 +219,16 @@ bool isLinearizableVector(VectorType type);
 
 /// Creates a TransferReadOp from `source`.
 ///
-/// The shape of the vector to read is specified via `inputVectorSizes`. If the
-/// shape of the output vector differs from the shape of the value being read,
-/// masking is used to avoid out-of-bounds accesses. Set
+/// If the shape of vector to read differs from the shape of the value being
+/// read, masking is used to avoid out-of-bounds accesses. Set
 /// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
 /// instead of explicit masks.
 ///
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
-                             ArrayRef<int64_t> inputVectorSizes,
+                             VectorType &vecToReadTy,
                              std::optional<Value> padValue = std::nullopt,
-                             bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> inputScalableVecDims = {});
+                             bool useInBoundsInsteadOfMasking = false);
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 14 additions & 12 deletions
@@ -1890,9 +1890,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 
   // Create masked TransferReadOp.
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
+      rewriter, loc, packOp.getSource(), readVecType, padValue,
+      useInBoundsInsteadOfMasking);
 
   // Create ShapeCastOp.
   auto shapeCastOp = vector::ShapeCastOp::create(
@@ -1977,9 +1976,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   }
 
   // -- Generate the read operation --
+  VectorType readVecType =
+      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+                      readScalableVectorFlags);
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
-      useInBoundsInsteadOfMasking, readScalableVectorFlags);
+      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+      useInBoundsInsteadOfMasking);
 
   // -- Generate the transpose operation --
   PackingMetadata packMetadata;
@@ -2025,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
+  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+      rewriter, loc, padOp.getSource(), readType, padValue,
+      /*useInBoundsInsteadOfMasking=*/false);
 
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2222,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
         state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
 
     Value read = mlir::vector::createReadOrMaskedRead(
-        rewriter, loc, opOperand.get(), readType.getShape(),
+        rewriter, loc, opOperand.get(), readType,
         /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
-        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+        /*useInBoundsInsteadOfMasking=*/false);
     vecOperands.push_back(read);
   }
 
@@ -3165,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
-      /*inputScalableVecSizes=*/{});
+      rewriter, loc, source, vecType, padValue,
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
 
   // Create write
   auto writeIndices =

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 18 additions & 18 deletions
@@ -317,51 +317,51 @@ bool vector::isLinearizableVector(VectorType type) {
 }
 
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                     Value source,
-                                     ArrayRef<int64_t> inputVectorSizes,
+                                     Value source, VectorType &vecToReadTy,
                                      std::optional<Value> padValue,
-                                     bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> inputScalableVecDims) {
-  assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+                                     bool useInBoundsInsteadOfMasking) {
+  assert(!llvm::is_contained(vecToReadTy.getScalableDims(),
+                             ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
-  assert(sourceShape.size() == inputVectorSizes.size() &&
+
+  int64_t vecToReadRank = vecToReadTy.getRank();
+  auto vecToReadShape = vecToReadTy.getShape();
+
+  assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
-                      inputScalableVecDims);
   assert((!padValue.has_value() ||
           padValue.value().getType() == sourceShapedType.getElementType()) &&
         "expected same pad element type to match source element type");
-  int64_t readRank = inputVectorSizes.size();
+
   auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
-  SmallVector<bool> inBoundsVal(readRank, true);
+  SmallVector<bool> inBoundsVal(vecToReadRank, true);
 
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the read indices.
-    for (unsigned i = 0; i < readRank; i++)
-      inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+    for (unsigned i = 0; i < vecToReadRank; i++)
+      inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
                        ShapedType::isStatic(sourceShape[i]);
   }
   auto transferReadOp = vector::TransferReadOp::create(
       builder, loc,
-      /*vectorType=*/vectorType,
+      /*vectorType=*/vecToReadTy,
       /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*indices=*/SmallVector<Value>(vecToReadRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/inBoundsVal);
 
-  if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+  if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+      useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       isa<MemRefType>(source.getType())
          ? memref::getMixedSizes(builder, loc, source)
          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
-                                  inputScalableVecDims);
+  auto maskType = vecToReadTy.clone(builder.getI1Type());
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)