@@ -1183,14 +1183,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   auto srcRank = extractOp.getTensor().getType().getRank();
   SmallVector<bool> inBounds(dstRank, true);

+  // Get the value to pad transfer reads with 0.
+  Value padding =
+      arith::getZeroConstant(rewriter, loc, resultType.getElementType());
+
   // 2a. Handle scalar broadcast access.
   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
     MLIRContext *ctx = rewriter.getContext();
     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-        loc, resultType, extractOp.getTensor(), transferReadIdxs,
+        loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
         permutationMap, inBounds);

     // Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1231,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   }

   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
-      inBounds);
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
+      permutationMap, inBounds);

   LDBG("Vectorised as contiguous load: " << extractOp);
   return VectorizationHookResult{VectorizationHookStatus::NewOp,
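
For reference, every call site in this commit threads the pad value through
explicitly, between the read indices and the permutation map. A minimal
sketch of the new builder shape, where `src`, `indices`, `permutationMap`,
and `inBounds` are illustrative names rather than code from this commit:

    // Materialize a zero of the result element type; lanes of the read
    // that fall out of bounds are filled with this value.
    Value padding =
        arith::getZeroConstant(rewriter, loc, resultType.getElementType());
    // The pad value is now an explicit operand after the indices.
    auto read = rewriter.create<vector::TransferReadOp>(
        loc, resultType, src, indices, padding, permutationMap, inBounds);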
@@ -1384,7 +1388,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
 /// performed to the maximal common vector size implied by the `linalgOp`
 /// iteration space. This eager broadcasting is introduced in the
 /// permutation_map of the vector.transfer_read operations. The eager
-/// broadcasting makes it trivial to detrmine where broadcast, transposes and
+/// broadcasting makes it trivial to determine where broadcast, transposes and
 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
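
The eager broadcasting described in this comment shows up in the IR as
constant-0 results in the read's permutation map. A hand-written MLIR
sketch, with shapes and SSA names chosen purely for illustration:

    %pad = arith.constant 0.0 : f32
    %v = vector.transfer_read %src[%c0], %pad
           {in_bounds = [true, false],
            permutation_map = affine_map<(d0) -> (0, d0)>}
           : tensor<4xf32>, vector<8x4xf32>

Here the 4-element read is broadcast across the leading dimension of the
8x4 result; broadcast dimensions must be marked in-bounds.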
@@ -1439,7 +1443,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, readMap);
+        loc, readType, opOperand->get(), indices,
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);

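
arith::getZeroConstant is a small Arith-dialect helper. Conceptually it
builds an arith.constant carrying the zero attribute of the requested type;
the snippet below is a simplified sketch, not the exact upstream
implementation (which also guards against unsupported element types):

    // Simplified sketch: materialize a zero-valued arith.constant of `type`
    // (e.g. FloatAttr 0.0 for floats, IntegerAttr 0 for integers).
    static Value getZeroConstantSketch(OpBuilder &builder, Location loc,
                                       Type type) {
      return builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(type));
    }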
@@ -2641,6 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,

   Value readValue = rewriter.create<vector::TransferReadOp>(
       loc, readType, copyOp.getSource(), indices,
+      /*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
     readValue =
@@ -3487,15 +3493,18 @@ struct Conv1DGenerator
     SmallVector<Value> resPadding(resShape.size(), zero);

     // Read the whole lhs, rhs and res in one shot (with zero padding).
-    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
-                                                        lhsPadding);
+    Value lhs = rewriter.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, lhsPadding,
+        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == ConvOperationKind::Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                    rhsPadding);
-    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
-                                                        resPadding);
+      rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, rhsPadding,
+          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
+    Value res = rewriter.create<vector::TransferReadOp>(
+        loc, resType, resShaped, resPadding,
+        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));

     // The base vectorization case for channeled convolution is input:
     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
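
A note on naming in these Conv1DGenerator hunks: the pre-existing
lhsPadding/rhsPadding/resPadding vectors are all-zero read *indices*, not
pad values; the scalar now passed as /*padding=*/ is what fills
out-of-bounds lanes. A sketch separating the two roles, with illustrative
variable names:

    // All-zero start indices for the read (named `lhsPadding` in this file).
    SmallVector<Value> lhsIndices(lhsShape.size(), zero);
    // Scalar pad value for any lanes that fall outside the source bounds.
    Value padValue = arith::getZeroConstant(rewriter, loc, lhsEltType);
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, lhsIndices, padValue);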
@@ -3742,19 +3751,22 @@ struct Conv1DGenerator
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
+        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
     auto maybeMaskedLhs = maybeMaskXferOp(
         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                        ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(
+        loc, rhsType, rhsShaped, ValueRange{zero, zero},
+        /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
     auto maybeMaskedRhs = maybeMaskXferOp(
         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero},
+        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
     auto maybeMaskedRes = maybeMaskXferOp(
         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
