Skip to content

Commit 560f397

Browse files
committed
use assumeDynamicDimsMatchVecSizes flag for named contractions
Signed-off-by: Ege Beysel <[email protected]>
1 parent f2234f6 commit 560f397

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,6 +2181,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
21812181

21822182
// Load operands.
21832183
SmallVector<Value> vecOperands;
2184+
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
21842185
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
21852186
// The operand vector shape is computed by mapping the canonical vector
21862187
// shape to the operand's domain. Further permutations are left as a part of
@@ -2191,12 +2192,22 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
21912192
Type elemType = getElementTypeOrSelf(opOperand.get());
21922193
VectorType readType =
21932194
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2195+
2196+
SmallVector<Value> indices(linalgOp.getShape(&opOperand).size(), zero);
2197+
Operation *read = vector::TransferReadOp::create(
2198+
rewriter, loc, readType, opOperand.get(), indices,
2199+
/*padding=*/std::nullopt, readMap);
2200+
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
2201+
Value readValue = read->getResult(0);
21942202

2195-
Value read = mlir::vector::createReadOrMaskedRead(
2196-
rewriter, loc, opOperand.get(), readType.getShape(),
2197-
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2198-
/*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
2199-
vecOperands.push_back(read);
2203+
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
2204+
// will be in-bounds.
2205+
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
2206+
SmallVector<bool> inBounds(readType.getRank(), true);
2207+
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
2208+
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
2209+
}
2210+
vecOperands.push_back(readValue);
22002211
}
22012212

22022213
// Remap iterators from linalg to vector.

0 commit comments

Comments
 (0)