Skip to content

Commit 4575f84

Browse files
committed
address reviewer comments
1 parent 7fc16e9 commit 4575f84

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,8 @@ def Vector_TransferReadOp :
15591559
"std::optional<Value>":$padding,
15601560
"AffineMap":$permutationMap,
15611561
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1562-
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
1562+
/// 3. Builder that sets padding to `padding` or poison if not provided and
1563+
/// permutation map to 'getMinorIdentityMap'.
15631564
OpBuilder<(ins "VectorType":$vectorType,
15641565
"Value":$source,
15651566
"ValueRange":$indices,

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,8 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
12571257
LLVM_DEBUG(permutationMap.print(dbgs()));
12581258

12591259
auto transfer = state.builder.create<vector::TransferReadOp>(
1260-
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, std::nullopt,
1261-
permutationMap);
1260+
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
1261+
/*padding=*/std::nullopt, permutationMap);
12621262

12631263
// Register replacement for future uses in the scope.
12641264
state.registerOpVectorReplacement(loadOp, transfer);

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,15 +1183,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11831183
auto srcRank = extractOp.getTensor().getType().getRank();
11841184
SmallVector<bool> inBounds(dstRank, true);
11851185

1186+
// Get the zero value to pad transfer reads with.
1187+
Value padding =
1188+
arith::getZeroConstant(rewriter, loc, resultType.getElementType());
1189+
11861190
// 2a. Handle scalar broadcast access.
11871191
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
11881192
MLIRContext *ctx = rewriter.getContext();
11891193
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
11901194
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
11911195

11921196
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1193-
loc, resultType, extractOp.getTensor(), transferReadIdxs,
1194-
arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
1197+
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
11951198
permutationMap, inBounds);
11961199

11971200
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1228,8 +1231,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12281231
}
12291232

12301233
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1231-
loc, resultType, extractOp.getTensor(), transferReadIdxs,
1232-
arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
1234+
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
12331235
permutationMap, inBounds);
12341236

12351237
LDBG("Vectorised as contiguous load: " << extractOp);
@@ -1442,7 +1444,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14421444

14431445
Operation *read = rewriter.create<vector::TransferReadOp>(
14441446
loc, readType, opOperand->get(), indices,
1445-
arith::getZeroConstant(rewriter, loc, elemType), readMap);
1447+
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
14461448
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
14471449
Value readValue = read->getResult(0);
14481450

@@ -2644,7 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
26442646

26452647
Value readValue = rewriter.create<vector::TransferReadOp>(
26462648
loc, readType, copyOp.getSource(), indices,
2647-
arith::getZeroConstant(rewriter, loc, srcElementType),
2649+
/*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
26482650
rewriter.getMultiDimIdentityMap(srcType.getRank()));
26492651
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
26502652
readValue =
@@ -3493,16 +3495,16 @@ struct Conv1DGenerator
34933495
// Read the whole lhs, rhs and res in one shot (with zero padding).
34943496
Value lhs = rewriter.create<vector::TransferReadOp>(
34953497
loc, lhsType, lhsShaped, lhsPadding,
3496-
arith::getZeroConstant(rewriter, loc, lhsEltType));
3498+
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
34973499
// This is needed only for Conv.
34983500
Value rhs = nullptr;
34993501
if (oper == ConvOperationKind::Conv)
35003502
rhs = rewriter.create<vector::TransferReadOp>(
35013503
loc, rhsType, rhsShaped, rhsPadding,
3502-
arith::getZeroConstant(rewriter, loc, rhsEltType));
3504+
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
35033505
Value res = rewriter.create<vector::TransferReadOp>(
35043506
loc, resType, resShaped, resPadding,
3505-
arith::getZeroConstant(rewriter, loc, resEltType));
3507+
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
35063508

35073509
// The base vectorization case for channeled convolution is input:
35083510
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3750,21 +3752,21 @@ struct Conv1DGenerator
37503752
// 0].
37513753
Value lhs = rewriter.create<vector::TransferReadOp>(
37523754
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3753-
arith::getZeroConstant(rewriter, loc, lhsEltType));
3755+
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
37543756
auto maybeMaskedLhs = maybeMaskXferOp(
37553757
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
37563758

37573759
// Read rhs slice of size {kw, c} @ [0, 0].
37583760
Value rhs = rewriter.create<vector::TransferReadOp>(
37593761
loc, rhsType, rhsShaped, ValueRange{zero, zero},
3760-
arith::getZeroConstant(rewriter, loc, rhsEltType));
3762+
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
37613763
auto maybeMaskedRhs = maybeMaskXferOp(
37623764
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
37633765

37643766
// Read res slice of size {n, w, c} @ [0, 0, 0].
37653767
Value res = rewriter.create<vector::TransferReadOp>(
37663768
loc, resType, resShaped, ValueRange{zero, zero, zero},
3767-
arith::getZeroConstant(rewriter, loc, resEltType));
3769+
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
37683770
auto maybeMaskedRes = maybeMaskXferOp(
37693771
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
37703772

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
173173
}
174174
SmallVector<bool> inBounds(indices.size(), true);
175175
return b.create<vector::TransferReadOp>(
176-
loc, cast<VectorType>(type), buffer, indices, std::nullopt,
176+
loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
177177
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
178178
}
179179

0 commit comments

Comments
 (0)