@@ -1183,18 +1183,14 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   auto srcRank = extractOp.getTensor().getType().getRank();
   SmallVector<bool> inBounds(dstRank, true);
 
-  // Get the value to pad transfer reads with 0.
-  Value padding =
-      arith::getZeroConstant(rewriter, loc, resultType.getElementType());
-
   // 2a. Handle scalar broadcast access.
   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
     MLIRContext *ctx = rewriter.getContext();
     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
 
     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-        loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
+        loc, resultType, extractOp.getTensor(), transferReadIdxs,
         permutationMap, inBounds);
 
     // Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1231,8 +1227,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   }
 
   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
-      permutationMap, inBounds);
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
+      inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
   return VectorizationHookResult{VectorizationHookStatus::NewOp,
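Taken together, the two hunks above reduce both `tensor.extract` vectorization paths to the shorter builder call sketched below. This is a minimal before/after sketch, assuming (the diff does not show it) a `vector::TransferReadOp` builder overload that supplies a default padding value when none is passed; all names are taken from the surrounding code:

```cpp
// Before: the caller materializes a zero padding value and passes it in.
Value padding =
    arith::getZeroConstant(rewriter, loc, resultType.getElementType());
auto read = rewriter.create<vector::TransferReadOp>(
    loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
    permutationMap, inBounds);

// After: no explicit padding operand; the builder is assumed to default it.
auto read = rewriter.create<vector::TransferReadOp>(
    loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
    inBounds);
```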
@@ -1388,7 +1384,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
 /// performed to the maximal common vector size implied by the `linalgOp`
 /// iteration space. This eager broadcasting is introduced in the
 /// permutation_map of the vector.transfer_read operations. The eager
-/// broadcasting makes it trivial to detrmine where broadcast, transposes and
+/// broadcasting makes it trivial to determine where broadcast, transposes and
 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
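The eager broadcasting this comment describes is exactly what the scalar-broadcast path in the first hunk builds: a permutation map whose results are all the affine constant 0, so the read fetches a single element and the map itself encodes the broadcast to the full destination vector shape. A condensed sketch using the names from that hunk (`srcRank`, `dstRank`, `rewriter`):

```cpp
// For srcRank == 2 and dstRank == 2 this yields (d0, d1) -> (0, 0): every
// result dimension reads element 0, i.e. one scalar is broadcast across the
// destination vector shape with no separate broadcast bookkeeping.
MLIRContext *ctx = rewriter.getContext();
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
auto permutationMap = AffineMap::get(srcRank, /*symbolCount=*/0, exprs, ctx);
```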
@@ -1443,8 +1439,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
 
     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices,
-        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
+        loc, readType, opOperand->get(), indices, readMap);
     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);
 
@@ -2646,7 +2641,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 
   Value readValue = rewriter.create<vector::TransferReadOp>(
       loc, readType, copyOp.getSource(), indices,
-      /*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
     readValue =
@@ -3493,18 +3487,15 @@ struct Conv1DGenerator
     SmallVector<Value> resPadding(resShape.size(), zero);
 
     // Read the whole lhs, rhs and res in one shot (with zero padding).
-    Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, lhsPadding,
-        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
+    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
+                                                        lhsPadding);
     // This is needed only for Conv.
     Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, rhsPadding,
-          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
-    Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, resPadding,
-        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
+      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                    rhsPadding);
+    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
+                                                        resPadding);
 
     // The base vectorization case for channeled convolution is input:
     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3751,22 +3742,19 @@ struct Conv1DGenerator
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
-        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     auto maybeMaskedLhs = maybeMaskXferOp(
         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
 
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(
-        loc, rhsType, rhsShaped, ValueRange{zero, zero},
-        /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
 
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero},
-        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
     auto maybeMaskedRes = maybeMaskXferOp(
         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());