Skip to content

Commit ca3e2b2

Browse files
committed
[mlir][vector] Avoid setting padding by default in vector transfer read, prefer ub.poison
Signed-off-by: Fabian Mora <[email protected]>
1 parent e980523 commit ca3e2b2

File tree

15 files changed

+95
-76
lines changed

15 files changed

+95
-76
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
154154
Value lhs, Value rhs);
155155

156156
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
157+
158+
/// Creates an `arith.constant` operation with a zero value of type `type`.
159+
Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
157160
} // namespace arith
158161
} // namespace mlir
159162

mlir/include/mlir/Dialect/Vector/IR/Vector.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {
2121

2222
let useDefaultAttributePrinterParser = 1;
2323
let hasConstantMaterializer = 1;
24-
let dependentDialects = ["arith::ArithDialect"];
24+
let dependentDialects = [
25+
"arith::ArithDialect",
26+
"ub::UBDialect"
27+
];
2528
}
2629

2730
// Base class for Vector dialect ops.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,30 +1543,28 @@ def Vector_TransferReadOp :
15431543
}];
15441544

15451545
let builders = [
1546-
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
1546+
/// 1. Builder that sets padding to `padding` or poison if not provided and
1547+
/// an empty mask (variant with attrs).
15471548
OpBuilder<(ins "VectorType":$vectorType,
15481549
"Value":$source,
15491550
"ValueRange":$indices,
1551+
"std::optional<Value>":$padding,
15501552
"AffineMapAttr":$permutationMapAttr,
15511553
"ArrayAttr":$inBoundsAttr)>,
1552-
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
1554+
/// 2. Builder that sets padding to `padding` or poison if not provided and
1555+
/// an empty mask (variant without attrs).
15531556
OpBuilder<(ins "VectorType":$vectorType,
15541557
"Value":$source,
15551558
"ValueRange":$indices,
1559+
"std::optional<Value>":$padding,
15561560
"AffineMap":$permutationMap,
15571561
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
15581562
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
15591563
OpBuilder<(ins "VectorType":$vectorType,
15601564
"Value":$source,
15611565
"ValueRange":$indices,
1562-
"Value":$padding,
1563-
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1564-
/// 4. Builder that sets padding to zero and permutation map to
1565-
/// 'getMinorIdentityMap'.
1566-
OpBuilder<(ins "VectorType":$vectorType,
1567-
"Value":$source,
1568-
"ValueRange":$indices,
1569-
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1566+
"std::optional<Value>":$padding,
1567+
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
15701568
];
15711569

15721570
let extraClassDeclaration = [{

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
12571257
LLVM_DEBUG(permutationMap.print(dbgs()));
12581258

12591259
auto transfer = state.builder.create<vector::TransferReadOp>(
1260-
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
1260+
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, std::nullopt,
1261+
permutationMap);
12611262

12621263
// Register replacement for future uses in the scope.
12631264
state.registerOpVectorReplacement(loadOp, transfer);

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,11 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
292292
return false;
293293
}
294294

295+
Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
296+
Type type) {
297+
return builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(type));
298+
}
299+
295300
//===----------------------------------------------------------------------===//
296301
// AddIOp
297302
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
426426
// Create the new `transfer_read`.
427427
auto newReadOp = rewriter.create<vector::TransferReadOp>(
428428
readOp.getLoc(), collapsedVT, collapsedMem, indices,
429+
readOp.getPadding(),
429430
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
430431

431432
// Cast back to the original vector type.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11911191

11921192
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
11931193
loc, resultType, extractOp.getTensor(), transferReadIdxs,
1194+
arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
11941195
permutationMap, inBounds);
11951196

11961197
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1228,9 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12271228
}
12281229

12291230
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1230-
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1231-
inBounds);
1231+
loc, resultType, extractOp.getTensor(), transferReadIdxs,
1232+
arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
1233+
permutationMap, inBounds);
12321234

12331235
LDBG("Vectorised as contiguous load: " << extractOp);
12341236
return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1384,7 +1386,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13841386
/// performed to the maximal common vector size implied by the `linalgOp`
13851387
/// iteration space. This eager broadcasting is introduced in the
13861388
/// permutation_map of the vector.transfer_read operations. The eager
1387-
/// broadcasting makes it trivial to detrmine where broadcast, transposes and
1389+
/// broadcasting makes it trivial to determine where broadcast, transposes and
13881390
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
13891391
/// the absence of good canonicalizations, the amount of work increases.
13901392
/// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1439,7 +1441,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14391441
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
14401442

14411443
Operation *read = rewriter.create<vector::TransferReadOp>(
1442-
loc, readType, opOperand->get(), indices, readMap);
1444+
loc, readType, opOperand->get(), indices,
1445+
arith::getZeroConstant(rewriter, loc, elemType), readMap);
14431446
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
14441447
Value readValue = read->getResult(0);
14451448

@@ -2641,6 +2644,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
26412644

26422645
Value readValue = rewriter.create<vector::TransferReadOp>(
26432646
loc, readType, copyOp.getSource(), indices,
2647+
arith::getZeroConstant(rewriter, loc, srcElementType),
26442648
rewriter.getMultiDimIdentityMap(srcType.getRank()));
26452649
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
26462650
readValue =
@@ -3487,15 +3491,18 @@ struct Conv1DGenerator
34873491
SmallVector<Value> resPadding(resShape.size(), zero);
34883492

34893493
// Read the whole lhs, rhs and res in one shot (with zero padding).
3490-
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3491-
lhsPadding);
3494+
Value lhs = rewriter.create<vector::TransferReadOp>(
3495+
loc, lhsType, lhsShaped, lhsPadding,
3496+
arith::getZeroConstant(rewriter, loc, lhsEltType));
34923497
// This is needed only for Conv.
34933498
Value rhs = nullptr;
34943499
if (oper == ConvOperationKind::Conv)
3495-
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3496-
rhsPadding);
3497-
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3498-
resPadding);
3500+
rhs = rewriter.create<vector::TransferReadOp>(
3501+
loc, rhsType, rhsShaped, rhsPadding,
3502+
arith::getZeroConstant(rewriter, loc, rhsEltType));
3503+
Value res = rewriter.create<vector::TransferReadOp>(
3504+
loc, resType, resShaped, resPadding,
3505+
arith::getZeroConstant(rewriter, loc, resEltType));
34993506

35003507
// The base vectorization case for channeled convolution is input:
35013508
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3742,19 +3749,22 @@ struct Conv1DGenerator
37423749
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
37433750
// 0].
37443751
Value lhs = rewriter.create<vector::TransferReadOp>(
3745-
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3752+
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3753+
arith::getZeroConstant(rewriter, loc, lhsEltType));
37463754
auto maybeMaskedLhs = maybeMaskXferOp(
37473755
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
37483756

37493757
// Read rhs slice of size {kw, c} @ [0, 0].
3750-
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3751-
ValueRange{zero, zero});
3758+
Value rhs = rewriter.create<vector::TransferReadOp>(
3759+
loc, rhsType, rhsShaped, ValueRange{zero, zero},
3760+
arith::getZeroConstant(rewriter, loc, rhsEltType));
37523761
auto maybeMaskedRhs = maybeMaskXferOp(
37533762
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
37543763

37553764
// Read res slice of size {n, w, c} @ [0, 0, 0].
37563765
Value res = rewriter.create<vector::TransferReadOp>(
3757-
loc, resType, resShaped, ValueRange{zero, zero, zero});
3766+
loc, resType, resShaped, ValueRange{zero, zero, zero},
3767+
arith::getZeroConstant(rewriter, loc, resEltType));
37583768
auto maybeMaskedRes = maybeMaskXferOp(
37593769
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
37603770

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
42614261
/// 1. Builder that sets padding to `padding` or poison if not provided, and an empty mask (variant with attrs).
42624262
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42634263
VectorType vectorType, Value source,
4264-
ValueRange indices, AffineMapAttr permutationMapAttr,
4264+
ValueRange indices, std::optional<Value> padding,
4265+
AffineMapAttr permutationMapAttr,
42654266
/*optional*/ ArrayAttr inBoundsAttr) {
4267+
42664268
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4267-
Value padding = builder.create<arith::ConstantOp>(
4268-
result.location, elemType, builder.getZeroAttr(elemType));
4269+
if (!padding)
4270+
padding = builder.create<ub::PoisonOp>(result.location, elemType);
42694271
build(builder, result, vectorType, source, indices, permutationMapAttr,
4270-
padding, /*mask=*/Value(), inBoundsAttr);
4272+
*padding, /*mask=*/Value(), inBoundsAttr);
42714273
}
42724274

42734275
/// 2. Builder that sets padding to `padding` or poison if not provided, and an empty mask (variant without attrs).
42744276
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42754277
VectorType vectorType, Value source,
4276-
ValueRange indices, AffineMap permutationMap,
4278+
ValueRange indices, std::optional<Value> padding,
4279+
AffineMap permutationMap,
42774280
std::optional<ArrayRef<bool>> inBounds) {
42784281
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
42794282
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
42804283
? builder.getBoolArrayAttr(inBounds.value())
42814284
: builder.getBoolArrayAttr(
42824285
SmallVector<bool>(vectorType.getRank(), false));
4283-
build(builder, result, vectorType, source, indices, permutationMapAttr,
4284-
inBoundsAttr);
4286+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4287+
if (!padding)
4288+
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4289+
build(builder, result, vectorType, source, indices, *padding,
4290+
permutationMapAttr, inBoundsAttr);
42854291
}
42864292

42874293
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
42884294
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42894295
VectorType vectorType, Value source,
4290-
ValueRange indices, Value padding,
4296+
ValueRange indices, std::optional<Value> padding,
42914297
std::optional<ArrayRef<bool>> inBounds) {
42924298
AffineMap permutationMap = getTransferMinorIdentityMap(
42934299
llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4296,23 +4302,14 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42964302
? builder.getBoolArrayAttr(inBounds.value())
42974303
: builder.getBoolArrayAttr(
42984304
SmallVector<bool>(vectorType.getRank(), false));
4305+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4306+
if (!padding)
4307+
padding = builder.create<ub::PoisonOp>(result.location, elemType);
42994308
build(builder, result, vectorType, source, indices, permutationMapAttr,
4300-
padding,
4309+
*padding,
43014310
/*mask=*/Value(), inBoundsAttr);
43024311
}
43034312

4304-
/// 4. Builder that sets padding to zero and permutation map to
4305-
/// 'getMinorIdentityMap'.
4306-
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4307-
VectorType vectorType, Value source,
4308-
ValueRange indices,
4309-
std::optional<ArrayRef<bool>> inBounds) {
4310-
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4311-
Value padding = builder.create<arith::ConstantOp>(
4312-
result.location, elemType, builder.getZeroAttr(elemType));
4313-
build(builder, result, vectorType, source, indices, padding, inBounds);
4314-
}
4315-
43164313
template <typename EmitFun>
43174314
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
43184315
EmitFun emitOpError) {

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
173173
}
174174
SmallVector<bool> inBounds(indices.size(), true);
175175
return b.create<vector::TransferReadOp>(
176-
loc, cast<VectorType>(type), buffer, indices,
176+
loc, cast<VectorType>(type), buffer, indices, std::nullopt,
177177
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
178178
}
179179

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,8 @@ class FlattenContiguousRowMajorTransferReadPattern
660660
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
661661
vectorType.getElementType());
662662
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
663-
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
663+
loc, flatVectorType, collapsedSource, collapsedIndices,
664+
transferReadOp.getPadding(), collapsedMap);
664665
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
665666

666667
// 4. Replace the old transfer_read with the new one reading from the

0 commit comments

Comments
 (0)