@@ -1183,15 +1183,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11831183 auto srcRank = extractOp.getTensor ().getType ().getRank ();
11841184 SmallVector<bool > inBounds (dstRank, true );
11851185
1186+ // Zero constant used as the padding value for the transfer reads below.
1187+ Value padding =
1188+ arith::getZeroConstant (rewriter, loc, resultType.getElementType ());
1189+
11861190 // 2a. Handle scalar broadcast access.
11871191 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
11881192 MLIRContext *ctx = rewriter.getContext ();
11891193 SmallVector<AffineExpr> exprs (dstRank, getAffineConstantExpr (0 , ctx));
11901194 auto permutationMap = AffineMap::get (srcRank, 0 , exprs, ctx);
11911195
11921196 auto transferReadOp = rewriter.create <vector::TransferReadOp>(
1193- loc, resultType, extractOp.getTensor (), transferReadIdxs,
1194- arith::getZeroConstant (rewriter, loc, resultType.getElementType ()),
1197+ loc, resultType, extractOp.getTensor (), transferReadIdxs, padding,
11951198 permutationMap, inBounds);
11961199
11971200 // Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1228,8 +1231,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12281231 }
12291232
12301233 auto transferReadOp = rewriter.create <vector::TransferReadOp>(
1231- loc, resultType, extractOp.getTensor (), transferReadIdxs,
1232- arith::getZeroConstant (rewriter, loc, resultType.getElementType ()),
1234+ loc, resultType, extractOp.getTensor (), transferReadIdxs, padding,
12331235 permutationMap, inBounds);
12341236
12351237 LDBG (" Vectorised as contiguous load: " << extractOp);
@@ -1442,7 +1444,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14421444
14431445 Operation *read = rewriter.create <vector::TransferReadOp>(
14441446 loc, readType, opOperand->get (), indices,
1445- arith::getZeroConstant (rewriter, loc, elemType), readMap);
1447+ /* padding= */ arith::getZeroConstant (rewriter, loc, elemType), readMap);
14461448 read = state.maskOperation (rewriter, read, linalgOp, indexingMap);
14471449 Value readValue = read->getResult (0 );
14481450
@@ -2644,7 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
26442646
26452647 Value readValue = rewriter.create <vector::TransferReadOp>(
26462648 loc, readType, copyOp.getSource (), indices,
2647- arith::getZeroConstant (rewriter, loc, srcElementType),
2649+ /* padding= */ arith::getZeroConstant (rewriter, loc, srcElementType),
26482650 rewriter.getMultiDimIdentityMap (srcType.getRank ()));
26492651 if (cast<VectorType>(readValue.getType ()).getRank () == 0 ) {
26502652 readValue =
@@ -3493,16 +3495,16 @@ struct Conv1DGenerator
34933495 // Read the whole lhs, rhs and res in one shot (with zero padding).
34943496 Value lhs = rewriter.create <vector::TransferReadOp>(
34953497 loc, lhsType, lhsShaped, lhsPadding,
3496- arith::getZeroConstant (rewriter, loc, lhsEltType));
3498+ /* padding= */ arith::getZeroConstant (rewriter, loc, lhsEltType));
34973499 // This is needed only for Conv.
34983500 Value rhs = nullptr ;
34993501 if (oper == ConvOperationKind::Conv)
35003502 rhs = rewriter.create <vector::TransferReadOp>(
35013503 loc, rhsType, rhsShaped, rhsPadding,
3502- arith::getZeroConstant (rewriter, loc, rhsEltType));
3504+ /* padding= */ arith::getZeroConstant (rewriter, loc, rhsEltType));
35033505 Value res = rewriter.create <vector::TransferReadOp>(
35043506 loc, resType, resShaped, resPadding,
3505- arith::getZeroConstant (rewriter, loc, resEltType));
3507+ /* padding= */ arith::getZeroConstant (rewriter, loc, resEltType));
35063508
35073509 // The base vectorization case for channeled convolution is input:
35083510 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3750,21 +3752,21 @@ struct Conv1DGenerator
37503752 // 0].
37513753 Value lhs = rewriter.create <vector::TransferReadOp>(
37523754 loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3753- arith::getZeroConstant (rewriter, loc, lhsEltType));
3755+ /* padding= */ arith::getZeroConstant (rewriter, loc, lhsEltType));
37543756 auto maybeMaskedLhs = maybeMaskXferOp (
37553757 lhsType.getShape (), lhsType.getScalableDims (), lhs.getDefiningOp ());
37563758
37573759 // Read rhs slice of size {kw, c} @ [0, 0].
37583760 Value rhs = rewriter.create <vector::TransferReadOp>(
37593761 loc, rhsType, rhsShaped, ValueRange{zero, zero},
3760- arith::getZeroConstant (rewriter, loc, rhsEltType));
3762+ /* padding= */ arith::getZeroConstant (rewriter, loc, rhsEltType));
37613763 auto maybeMaskedRhs = maybeMaskXferOp (
37623764 rhsType.getShape (), rhsType.getScalableDims (), rhs.getDefiningOp ());
37633765
37643766 // Read res slice of size {n, w, c} @ [0, 0, 0].
37653767 Value res = rewriter.create <vector::TransferReadOp>(
37663768 loc, resType, resShaped, ValueRange{zero, zero, zero},
3767- arith::getZeroConstant (rewriter, loc, resEltType));
3769+ /* padding= */ arith::getZeroConstant (rewriter, loc, resEltType));
37683770 auto maybeMaskedRes = maybeMaskXferOp (
37693771 resType.getShape (), resType.getScalableDims (), res.getDefiningOp ());
37703772
0 commit comments