@@ -1191,6 +1191,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
 
     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, resultType, extractOp.getTensor(), transferReadIdxs,
+        arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
         permutationMap, inBounds);
 
     // Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1228,9 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   }
 
   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
-      inBounds);
+      loc, resultType, extractOp.getTensor(), transferReadIdxs,
+      arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
+      permutationMap, inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
   return VectorizationHookResult{VectorizationHookStatus::NewOp,
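
Every hunk in this patch makes the same mechanical change: the padding value of `vector::TransferReadOp`, previously left implicit, is now passed explicitly as a zero constant of the element type being read, built with `arith::getZeroConstant`. A minimal sketch of the new builder-call shape; `source`, `indices`, `permMap`, and `inBounds` are placeholders for values each call site already has:

  // Build an explicit zero of the element type to use as xfer_read padding.
  Value padding =
      arith::getZeroConstant(rewriter, loc, vecType.getElementType());
  auto read = rewriter.create<vector::TransferReadOp>(
      loc, vecType, source, indices, padding, permMap, inBounds);
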
@@ -1384,7 +1386,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
 /// performed to the maximal common vector size implied by the `linalgOp`
 /// iteration space. This eager broadcasting is introduced in the
 /// permutation_map of the vector.transfer_read operations. The eager
-/// broadcasting makes it trivial to detrmine where broadcast, transposes and
+/// broadcasting makes it trivial to determine where broadcast, transposes and
 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
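
To make the doc comment above concrete: a broadcast is encoded directly in the transfer_read's permutation_map as a constant-zero result, which is why no separate bookkeeping is needed to locate broadcasts later. A hedged sketch of building such a map for a 2-D source read into a vector whose leading dimension is broadcast (names are illustrative):

  // (d0, d1) -> (0, d1): the constant-0 result marks a broadcast dimension;
  // the d1 result reads the source's second dimension contiguously.
  AffineMap bcastMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/0,
      {rewriter.getAffineConstantExpr(0), rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
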
@@ -1439,7 +1441,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
 
     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, readMap);
+        loc, readType, opOperand->get(), indices,
+        arith::getZeroConstant(rewriter, loc, elemType), readMap);
     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);
 
@@ -2641,6 +2644,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 
   Value readValue = rewriter.create<vector::TransferReadOp>(
       loc, readType, copyOp.getSource(), indices,
+      arith::getZeroConstant(rewriter, loc, srcElementType),
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
     readValue =
@@ -3487,15 +3491,18 @@ struct Conv1DGenerator
     SmallVector<Value> resPadding(resShape.size(), zero);
 
     // Read the whole lhs, rhs and res in one shot (with zero padding).
-    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
-                                                        lhsPadding);
+    Value lhs = rewriter.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, lhsPadding,
+        arith::getZeroConstant(rewriter, loc, lhsEltType));
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == ConvOperationKind::Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                    rhsPadding);
-    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
-                                                        resPadding);
+      rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, rhsPadding,
+          arith::getZeroConstant(rewriter, loc, rhsEltType));
+    Value res = rewriter.create<vector::TransferReadOp>(
+        loc, resType, resShaped, resPadding,
+        arith::getZeroConstant(rewriter, loc, resEltType));
 
     // The base vectorization case for channeled convolution is input:
     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
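
The shapes named in the comment above, written out as the vector types the three reads would produce in the base channeled case; the dimensions `n, w, c, kw, f` and the f32 element type are illustrative placeholders:

  // Hypothetical concrete types for input {n,w,c}, weight {kw,c,f},
  // output {n,w,f}.
  auto lhsType = VectorType::get({n, w, c}, rewriter.getF32Type());
  auto rhsType = VectorType::get({kw, c, f}, rewriter.getF32Type());
  auto resType = VectorType::get({n, w, f}, rewriter.getF32Type());
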
@@ -3742,19 +3749,22 @@ struct Conv1DGenerator
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
+        arith::getZeroConstant(rewriter, loc, lhsEltType));
     auto maybeMaskedLhs = maybeMaskXferOp(
         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
 
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                        ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(
+        loc, rhsType, rhsShaped, ValueRange{zero, zero},
+        arith::getZeroConstant(rewriter, loc, rhsEltType));
     auto maybeMaskedRhs = maybeMaskXferOp(
         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
 
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero},
+        arith::getZeroConstant(rewriter, loc, resEltType));
     auto maybeMaskedRes = maybeMaskXferOp(
         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());