@@ -97,9 +97,6 @@ struct TransferReadPermutationLowering
9797 matchAndRewriteMaskableOp (vector::TransferReadOp op,
9898 MaskingOpInterface maskOp,
9999 PatternRewriter &rewriter) const override {
100- // TODO: support 0-d corner case.
101- if (op.getTransferRank () == 0 )
102- return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
103100 // TODO: Support transfer_read inside MaskOp case.
104101 if (maskOp)
105102 return rewriter.notifyMatchFailure (op, " Masked case not supported" );
@@ -326,9 +323,6 @@ struct TransferOpReduceRank
326323 matchAndRewriteMaskableOp (vector::TransferReadOp op,
327324 MaskingOpInterface maskOp,
328325 PatternRewriter &rewriter) const override {
329- // TODO: support 0-d corner case.
330- if (op.getTransferRank () == 0 )
331- return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
332326 // TODO: support masked case.
333327 if (maskOp)
334328 return rewriter.notifyMatchFailure (op, " Masked case not supported" );
@@ -518,7 +512,7 @@ struct VectorLoadToMemrefLoadLowering
518512 }
519513};
520514
521- // / Replace a vector.store with a vector.extractelement + memref.store.
515+ // / Replace a 0-d vector.store with a vector.extractelement + memref.store.
522516struct VectorStoreToMemrefStoreLowering
523517 : public OpRewritePattern<vector::StoreOp> {
524518 using OpRewritePattern::OpRewritePattern;
@@ -530,9 +524,15 @@ struct VectorStoreToMemrefStoreLowering
530524 return rewriter.notifyMatchFailure (storeOp, " not single element vector" );
531525
532526 Value extracted;
533- SmallVector<int64_t > indices (vecType.getRank (), 0 );
534- extracted = rewriter.create <vector::ExtractOp>(
535- storeOp.getLoc (), storeOp.getValueToStore (), indices);
527+ if (vecType.getRank () == 0 ) {
528+ // TODO: Unifiy once ExtractOp supports 0-d vectors.
529+ extracted = rewriter.create <vector::ExtractElementOp>(
530+ storeOp.getLoc (), storeOp.getValueToStore ());
531+ } else {
532+ SmallVector<int64_t > indices (vecType.getRank (), 0 );
533+ extracted = rewriter.create <vector::ExtractOp>(
534+ storeOp.getLoc (), storeOp.getValueToStore (), indices);
535+ }
536536
537537 rewriter.replaceOpWithNewOp <memref::StoreOp>(
538538 storeOp, extracted, storeOp.getBase (), storeOp.getIndices ());
0 commit comments