
Commit 1a5ef00

banach-space authored and Lukacma committed
[mlir][vector][nfc] Update tests for folding mem operations (llvm#164255)
Tests in "fold_maskedload_to_load_all_true_dynamic" excercise folders for: * vector.maskedload, vector.maskedstore, vector.scatter, vector.gather, vector.compressstore, vector.expandload. This patch renames and documents these tests in accordance with: * https://mlir.llvm.org/getting_started/TestingGuide/ Note: the updated tests are referenced in the Test Formatting Best Practices section of the MLIR testing guide: * https://mlir.llvm.org/getting_started/TestingGuide/#test-formatting-best-practices Keeping them aligned with the guidelines ensures consistency and clarity across MLIR’s test suite.
1 parent 76a1b3a commit 1a5ef00
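
For context, the folds exercised by these tests rewrite masked memory ops whose mask is a compile-time constant: an all-true mask reduces the op to its unmasked equivalent, and an all-false mask removes it (loads fold to the pass-through value, stores disappear). Below is a minimal sketch of the all-true maskedload case, mirroring the test content in the diff; the function name @example_maskedload_all_true is illustrative only and not part of the patch.

    // Before -test-vector-to-vector-lowering: the mask is statically all-true.
    func.func @example_maskedload_all_true(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
      %c0 = arith.constant 0 : index
      %mask = vector.constant_mask [16] : vector<16xi1>
      %ld = vector.maskedload %base[%c0], %mask, %pass_thru
          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
      return %ld : vector<16xf32>
    }

    // After the pass, MaskedLoadFolder replaces the masked op with a plain load:
    //   %ld = vector.load %base[%c0] : memref<?xf32>, vector<16xf32>
    // With vector.constant_mask [0] (all-false), the load folds to %pass_thru instead.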

File tree: 1 file changed (+106, -78 lines)


mlir/test/Dialect/Vector/vector-mem-transforms.mlir

Lines changed: 106 additions & 78 deletions
@@ -1,134 +1,154 @@
 // RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s

-// CHECK-LABEL: func @maskedload0(
-// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+//-----------------------------------------------------------------------------
+// [Pattern: MaskedLoadFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_maskedload_all_true_dynamic(
+// CHECK-SAME: %[[BASE:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[IDX:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
+func.func @fold_maskedload_all_true_dynamic(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.maskedload %base[%c0], %mask, %pass_thru
     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @maskedload1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_maskedload_all_true_static(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[IDX:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
+func.func @fold_maskedload_all_true_static(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.maskedload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @maskedload2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-func.func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_maskedload_all_false_static(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[PASS_THRU]] : vector<16xf32>
+func.func @fold_maskedload_all_false_static(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
   %ld = vector.maskedload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @maskedload3(
-// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-DAG: %[[C:.*]] = arith.constant 8 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_maskedload_dynamic_non_zero_idx(
+// CHECK-SAME: %[[BASE:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[IDX:.*]] = arith.constant 8 : index
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
+func.func @fold_maskedload_dynamic_non_zero_idx(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c8 = arith.constant 8 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.maskedload %base[%c8], %mask, %pass_thru
     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @maskedstore1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
-// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
+//-----------------------------------------------------------------------------
+// [Pattern: MaskedStoreFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_maskedstore_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0 : index
+// CHECK-NEXT: vector.store %[[VALUE]], %[[BASE]][%[[IDX]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT: return
-func.func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_maskedstore_all_true(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }

-// CHECK-LABEL: func @maskedstore2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-LABEL: func @fold_maskedstore_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
 // CHECK-NEXT: return
-func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_maskedstore_all_false(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
   vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }

-// CHECK-LABEL: func @gather1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+//-----------------------------------------------------------------------------
+// [Pattern: GatherFolder]
+//-----------------------------------------------------------------------------
+
+/// There is no alternative (i.e. simpler) Op for this, hence no-fold.
+
+// CHECK-LABEL: func @no_fold_gather_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
 // CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-NEXT: %[[G:.*]] = vector.gather %[[BASE]][%[[C]]] [%[[INDICES]]], %[[M]], %[[PASS_THRU]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT: return %[[G]] : vector<16xf32>
-func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+func.func @no_fold_gather_all_true(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @gather2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-NEXT: return %[[A2]] : vector<16xf32>
-func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_gather_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[PASS_THRU]] : vector<16xf32>
+func.func @fold_gather_all_true(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
   %ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @scatter1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+//-----------------------------------------------------------------------------
+// [Pattern: ScatterFolder]
+//-----------------------------------------------------------------------------
+
+/// There is no alternative (i.e. simpler) Op for this, hence no-fold.
+
+// CHECK-LABEL: func @no_fold_scatter_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
 // CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
 // CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-NEXT: vector.scatter %[[BASE]][%[[C]]] [%[[INDICES]]], %[[M]], %[[VALUE]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT: return
-func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+func.func @no_fold_scatter_all_true(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }

-// CHECK-LABEL: func @scatter2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+// CHECK-LABEL: func @fold_scatter_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
 // CHECK-NEXT: return
-func.func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+func.func @fold_scatter_all_false(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
   %mask = vector.constant_mask [0] : vector<16xi1>
@@ -137,50 +157,58 @@ func.func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vec
   return
 }

-// CHECK-LABEL: func @expand1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+//-----------------------------------------------------------------------------
+// [Pattern: ExpandLoadFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_expandload_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.load %[[BASE]][%[[C]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+func.func @fold_expandload_all_true(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.expandload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @expand2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-func.func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_expandload_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[PASS_THRU]] : vector<16xf32>
+func.func @fold_expandload_all_false(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
   %ld = vector.expandload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }

-// CHECK-LABEL: func @compress1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+//-----------------------------------------------------------------------------
+// [Pattern: CompressStoreFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_compressstore_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
 // CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.store %[[VALUE]], %[[BASE]][%[[C]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT: return
-func.func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_compressstore_all_true(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }

-// CHECK-LABEL: func @compress2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-LABEL: func @fold_compressstore_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
 // CHECK-NEXT: return
-func.func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_compressstore_all_false(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
   vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
