Skip to content

Commit 596034d

Browse files
committed
[mlir][vector] Tighten the semantics of vector.{load|store}
This change refines the verifier for `vector.load` and `vector.store` to disallow the use of vectors with higher rank than the source or destination memref. For example, the following is now rejected: ```mlir %0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8> vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8> ``` This pattern was previously used in SME end-to-end tests and "happened" to work by implicitly assuming row-major memory layout. However, there is no guarantee that such an assumption will always hold, and we should avoid relying on it unless it can be enforced deterministically. Notably, production ArmSME lowering pipelines do not rely on this behavior. Instead, the expected usage (illustrated here with scalable vector syntax) would be: ```mlir %0 = vector.load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8> ``` This PR updates the verifier accordingly and adjusts all affected tests. These tests are either removed (if no longer relevant) or updated to use memrefs with appropriately matching rank.
1 parent 6e7c40b commit 596034d

File tree

7 files changed

+130
-85
lines changed

7 files changed

+130
-85
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5099,6 +5099,10 @@ LogicalResult vector::LoadOp::verify() {
50995099
if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
51005100
return failure();
51015101

5102+
if (memRefTy.getRank() < resVecTy.getRank())
5103+
return emitOpError(
5104+
"destination memref has lower rank than the result vector");
5105+
51025106
// Checks for vector memrefs.
51035107
Type memElemTy = memRefTy.getElementType();
51045108
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
@@ -5131,6 +5135,9 @@ LogicalResult vector::StoreOp::verify() {
51315135
if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
51325136
return failure();
51335137

5138+
if (memRefTy.getRank() < valueVecTy.getRank())
5139+
return emitOpError("source memref has lower rank than the vector to store");
5140+
51345141
// Checks for vector memrefs.
51355142
Type memElemTy = memRefTy.getElementType();
51365143
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -718,18 +718,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
718718

719719
// -----
720720

721-
// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
722-
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xi8>)
723-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
724-
// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C0]]] : memref<?xi8>, vector<[16]x[16]xi8>
725-
func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
726-
%c0 = arith.constant 0 : index
727-
%tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
728-
return %tile : vector<[16]x[16]xi8>
729-
}
730-
731-
// -----
732-
733721
// CHECK-LABEL: @vector_load_i16(
734722
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
735723
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -819,18 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
819819

820820
// -----
821821

822-
func.func @fold_vector_load_subview(
823-
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
824-
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
825-
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
826-
return %1 : vector<12x32xf32>
822+
func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
823+
%off1 : index,
824+
%off2 : index,
825+
%dim1 : index,
826+
%dim2 : index,
827+
%idx : index) -> vector<12x32xf32> {
828+
829+
%0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
830+
%1 = vector.load %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<12x32xf32>
831+
return %1 : vector<12x32xf32>
827832
}
828833

829-
// CHECK: func @fold_vector_load_subview
830-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
831-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
832-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
833-
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
834+
// CHECK: #[[$ATTR_46:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
835+
// CHECK-LABEL: func.func @fold_vector_load_subview(
836+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
837+
// CHECK-SAME: %[[OFF_1:[a-zA-Z0-9$._-]*]]: index,
838+
// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
839+
// CHECK-SAME: %[[DIM_1:[a-zA-Z0-9$._-]*]]: index,
840+
// CHECK-SAME: %[[DIM_2:[a-zA-Z0-9$._-]*]]: index,
841+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index) -> vector<12x32xf32> {
842+
// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_1]], %[[IDX]]]
843+
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_2]], %[[IDX]]]
844+
// CHECK: %[[VAL_8:.*]] = vector.load %[[SRC]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref<24x64xf32>, vector<12x32xf32>
834845

835846
// -----
836847

@@ -851,20 +862,32 @@ func.func @fold_vector_maskedload_subview(
851862

852863
// -----
853864

854-
func.func @fold_vector_store_subview(
855-
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
856-
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
857-
vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
858-
return
865+
func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
866+
%off1 : index,
867+
%off2 : index,
868+
%vec: vector<2x32xf32>,
869+
%idx : index,
870+
%dim1 : index,
871+
%dim2 : index) -> () {
872+
873+
%0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
874+
vector.store %vec, %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>> , vector<2x32xf32>
875+
return
859876
}
860877

861-
// CHECK: func @fold_vector_store_subview
862-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
863-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
864-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
865-
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
866-
// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<2x32xf32>
867-
// CHECK: return
878+
// CHECK: #[[$ATTR_47:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
879+
880+
// CHECK-LABEL: func.func @fold_vector_store_subview(
881+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
882+
// CHECK-SAME: %[[OFF1:[a-zA-Z0-9$._-]*]]: index,
883+
// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
884+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9$._-]*]]: vector<2x32xf32>,
885+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index,
886+
// CHECK-SAME: %[[VAL_5:[a-zA-Z0-9$._-]*]]: index,
887+
// CHECK-SAME: %[[VAL_6:[a-zA-Z0-9$._-]*]]: index) {
888+
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF1]], %[[IDX]]]
889+
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF_2]], %[[IDX]]]
890+
// CHECK: vector.store %[[VEC]], %[[SRC]]{{\[}}%[[VAL_7]], %[[VAL_8]]] : memref<24x64xf32>, vector<2x32xf32>
868891

869892
// -----
870893

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,13 +1743,11 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
17431743

17441744
// -----
17451745

1746-
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
1746+
func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32>, %rhs : vector<[4]xf32>) {
17471747
%idx = arith.constant 0 : index
1748-
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
1749-
%1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
17501748

17511749
// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
1752-
%op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
1750+
%op = vector.outerproduct %lhs, %rhs : vector<[4]x[4]xf32>, vector<[4]xf32>
17531751
}
17541752

17551753
// -----
@@ -1870,3 +1868,29 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
18701868
: vector<[16]xf32> -> vector<[16]xf32>
18711869
return %0 : vector<[16]xf32>
18721870
}
1871+
1872+
// -----
1873+
1874+
//===----------------------------------------------------------------------===//
1875+
// vector.load
1876+
//===----------------------------------------------------------------------===//
1877+
1878+
func.func @vector_load(%src : memref<?xi8>) {
1879+
%c0 = arith.constant 0 : index
1880+
// expected-error @+1 {{'vector.load' op destination memref has lower rank than the result vector}}
1881+
%0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8>
1882+
return
1883+
}
1884+
1885+
// -----
1886+
1887+
//===----------------------------------------------------------------------===//
1888+
// vector.store
1889+
//===----------------------------------------------------------------------===//
1890+
1891+
func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
1892+
%c0 = arith.constant 0 : index
1893+
// expected-error @+1 {{'vector.store' op source memref has lower rank than the vector to store}}
1894+
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
1895+
return
1896+
}

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
44
// CHECK-SAME: %[[MEM:.*]]: memref<f32>
5-
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1xf32>
6-
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
5+
// CHECK-SAME: %[[VEC:.*]]: vector<f32>
6+
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<f32>) {
77
%f0 = arith.constant 0.0 : f32
88

99
// CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
@@ -12,8 +12,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
1212
// CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
1313
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
1414

15-
// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
16-
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
15+
// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<f32>
16+
vector.store %vec, %mem[] : memref<f32>, vector<f32>
1717

1818
return
1919
}

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func.func @entry() {
1717
%za_s_size = arith.muli %svl_s, %svl_s : index
1818

1919
// Allocate memory.
20-
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
20+
%mem1 = memref.alloca(%za_s_size, %svl_s) : memref<?x?xi32>
2121

2222
// Fill each "row" of "mem1" with row number.
2323
//
@@ -29,15 +29,15 @@ func.func @entry() {
2929
// 3, 3, 3, 3
3030
//
3131
%init_0 = arith.constant 0 : i32
32-
scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
32+
scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
3333
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
34-
vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
34+
vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
3535
%val_next = arith.addi %val, %c1_i32 : i32
3636
scf.yield %val_next : i32
3737
}
3838

3939
// Load tile from "mem1".
40-
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
40+
%tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
4141

4242
// Transpose tile.
4343
%transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>

0 commit comments

Comments
 (0)