@@ -1294,10 +1294,6 @@ struct UnrollTransferReadConversion
12941294
12951295 // / Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
12961296 // / accesses, and broadcasts and transposes in permutation maps.
1297- // /
1298- // / When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
1299- // / `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
1300- // / MemRef and Tensor source, respectively).
13011297 LogicalResult matchAndRewrite (TransferReadOp xferOp,
13021298 PatternRewriter &rewriter) const override {
13031299 if (xferOp.getVectorType ().getRank () <= options.targetRank )
@@ -1345,32 +1341,20 @@ struct UnrollTransferReadConversion
13451341
13461342 auto inBoundsAttr = dropFirstElem (b, xferOp.getInBoundsAttr ());
13471343
1348- // A value that's read after rank-reducing the original
1349- // vector.transfer_read Op.
1350- Value unpackedReadRes;
1351- if (newXferVecType.getRank () != 0 ) {
1352- // Unpacking Vector that's rank > 2
1353- // (use vector.transfer_read to load a rank-reduced vector)
1354- unpackedReadRes = b.create <vector::TransferReadOp>(
1355- loc, newXferVecType, xferOp.getBase (), xferIndices,
1356- AffineMapAttr::get (unpackedPermutationMap (b, xferOp)),
1357- xferOp.getPadding (), Value (), inBoundsAttr);
1358- maybeAssignMask (b, xferOp,
1359- dyn_cast<vector::TransferReadOp>(
1360- unpackedReadRes.getDefiningOp ()),
1361- i);
1362- } else {
1363- // Unpacking Vector that's rank == 1
1364- // (use memref.load/tensor.extract to load a scalar)
1365- unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase ().getType ())
1366- ? b.create <memref::LoadOp>(
1367- loc, xferOp.getBase (), xferIndices)
1368- .getResult ()
1369- : b.create <tensor::ExtractOp>(
1370- loc, xferOp.getBase (), xferIndices)
1371- .getResult ();
1344+ auto newXferOp = b.create <vector::TransferReadOp>(
1345+ loc, newXferVecType, xferOp.getBase (), xferIndices,
1346+ AffineMapAttr::get (unpackedPermutationMap (b, xferOp)),
1347+ xferOp.getPadding (), Value (), inBoundsAttr);
1348+ maybeAssignMask (b, xferOp, newXferOp, i);
1349+
1350+ Value valToInser = newXferOp.getResult ();
1351+ if (newXferVecType.getRank () == 0 ) {
1352+ // vector.insert does not accept rank-0 as the non-indexed
1353+ // argument. Extract the scalar before inserting.
1354+ valToInser = b.create <vector::ExtractOp>(loc, valToInser,
1355+ SmallVector<int64_t >());
13721356 }
1373- return b.create <vector::InsertOp>(loc, unpackedReadRes , vec,
1357+ return b.create <vector::InsertOp>(loc, valToInser , vec,
13741358 insertionIndices);
13751359 },
13761360 /* outOfBoundsCase=*/
0 commit comments