Skip to content

Commit 6a5ac3c

Browse files
committed
[MLIR][Vector] Remove 0-d corner case condition.
1 parent 6db82e0 commit 6a5ac3c

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@ struct TransferReadPermutationLowering
9797
matchAndRewriteMaskableOp(vector::TransferReadOp op,
9898
MaskingOpInterface maskOp,
9999
PatternRewriter &rewriter) const override {
100-
// TODO: support 0-d corner case.
101-
if (op.getTransferRank() == 0)
102-
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
103100
// TODO: Support transfer_read inside MaskOp case.
104101
if (maskOp)
105102
return rewriter.notifyMatchFailure(op, "Masked case not supported");
@@ -326,9 +323,6 @@ struct TransferOpReduceRank
326323
matchAndRewriteMaskableOp(vector::TransferReadOp op,
327324
MaskingOpInterface maskOp,
328325
PatternRewriter &rewriter) const override {
329-
// TODO: support 0-d corner case.
330-
if (op.getTransferRank() == 0)
331-
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
332326
// TODO: support masked case.
333327
if (maskOp)
334328
return rewriter.notifyMatchFailure(op, "Masked case not supported");
@@ -518,7 +512,7 @@ struct VectorLoadToMemrefLoadLowering
518512
}
519513
};
520514

521-
/// Replace a vector.store with a vector.extractelement + memref.store.
515+
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
522516
struct VectorStoreToMemrefStoreLowering
523517
: public OpRewritePattern<vector::StoreOp> {
524518
using OpRewritePattern::OpRewritePattern;
@@ -530,9 +524,15 @@ struct VectorStoreToMemrefStoreLowering
530524
return rewriter.notifyMatchFailure(storeOp, "not single element vector");
531525

532526
Value extracted;
533-
SmallVector<int64_t> indices(vecType.getRank(), 0);
534-
extracted = rewriter.create<vector::ExtractOp>(
535-
storeOp.getLoc(), storeOp.getValueToStore(), indices);
527+
if (vecType.getRank() == 0) {
528+
// TODO: Unifiy once ExtractOp supports 0-d vectors.
529+
extracted = rewriter.create<vector::ExtractElementOp>(
530+
storeOp.getLoc(), storeOp.getValueToStore());
531+
} else {
532+
SmallVector<int64_t> indices(vecType.getRank(), 0);
533+
extracted = rewriter.create<vector::ExtractOp>(
534+
storeOp.getLoc(), storeOp.getValueToStore(), indices);
535+
}
536536

537537
rewriter.replaceOpWithNewOp<memref::StoreOp>(
538538
storeOp, extracted, storeOp.getBase(), storeOp.getIndices());

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2971,8 +2971,9 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
29712971
// CHECK-LABEL: func @vector_store_op_0d
29722972
// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
29732973
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
2974-
// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
2975-
// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
2974+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
2975+
// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
2976+
// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
29762977

29772978
// -----
29782979

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
77
%f0 = arith.constant 0.0 : f32
88

99
// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
10+
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
1011
%0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
1112

12-
// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref<f32>
13+
// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
14+
// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
1315
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
1416

15-
// CHECK-NEXT: %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
16-
// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
17+
// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
18+
// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
1719
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
1820

1921
return

0 commit comments

Comments
 (0)