11// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22
3- // CHECK-LABEL: func @hoist_vector_transfer_pairs(
4- // CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
5- // CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
6- // CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
7- // CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
8- // CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9- // CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
10- // CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
11- // CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
12- // CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
13- // CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
14- // CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
15- func.func @hoist_vector_transfer_pairs (
16- %memref0: memref <?x?xf32 >, %memref1: memref <?x?xf32 >, %memref2: memref <?x?xf32 >,
17- %memref3: memref <?x?xf32 >, %memref4: memref <?x?xf32 >, %memref5: memref <?x?xf32 >,
18- %val: index , %lb : index , %ub : index , %step: index , %cmp: i1 ) {
3+ ///----------------------------------------------------------------------------------------
4+ /// Tests for vector.transfer_read + vector.transfer_write pairs
5+ ///
6+ /// * Nested in double loops
7+ /// * Indices depend on induction variables
8+ ///----------------------------------------------------------------------------------------
9+
10+ // CHECK-LABEL: func @mem_use_outside
11+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
12+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
13+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
14+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
15+ func.func @mem_use_outside(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
16+   %pad = arith.constant 0.0 : f32
17+ // Positive test: the read/write pair on %mem[%i, %i] does not depend on %j and the only other use of %mem comes after both loops. The CHECK lines expect the pair to move out of the inner loop, with the value carried through iter_args.
18+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
19+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
20+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
21+ // CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
22+ // CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
23+ // CHECK: scf.yield %[[USE]] : vector<1xf32>
24+ // CHECK: }
25+ // CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
26+ // CHECK: }
27+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
28+   scf.for %i = %lb to %ub step %step {
29+     scf.for %j = %lb to %ub step %step {
30+       %read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
31+       %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
32+       vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
33+     }
34+   }
35+   "mem_use"(%mem) : (memref<?x?xf32>) -> ()  // use of %mem outside both loops
36+   return
37+ }
38+
39+ module attributes {transform.with_named_sequence} {  // transform script for this split
40+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
41+     %0 = transform.structured.match ops{["func.func"]} in %arg1
42+       : (!transform.any_op) -> !transform.any_op  // handle to the matched func.func ops
43+     transform.structured.hoist_redundant_vector_transfers %0
44+       : (!transform.any_op) -> !transform.any_op
45+     transform.yield
46+   }
47+ }
48+
49+ // -----
50+
51+ // CHECK-LABEL: func @mem_use_inside_outer_loop
52+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
53+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
54+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
55+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
56+ func.func @mem_use_inside_outer_loop(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
57+   %pad = arith.constant 0.0 : f32
58+ // Positive test: the extra use of %mem sits in the outer loop, after the inner loop. The CHECK lines expect the read/write pair to move out of the inner loop only, with the use following the hoisted write.
59+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
60+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
61+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
62+ // CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
63+ // CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
64+ // CHECK: scf.yield %[[USE]] : vector<1xf32>
65+ // CHECK: }
66+ // CHECK: vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
67+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
68+ // CHECK: }
69+   scf.for %i = %lb to %ub step %step {
70+     scf.for %j = %lb to %ub step %step {
71+       %read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
72+       %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
73+       vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
74+     }
75+     "mem_use"(%mem) : (memref<?x?xf32>) -> ()  // use of %mem inside the outer loop only
76+   }
77+   return
78+ }
79+
80+ module attributes {transform.with_named_sequence} {  // transform script for this split
81+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
82+     %0 = transform.structured.match ops{["func.func"]} in %arg1
83+       : (!transform.any_op) -> !transform.any_op  // handle to the matched func.func ops
84+     transform.structured.hoist_redundant_vector_transfers %0
85+       : (!transform.any_op) -> !transform.any_op
86+     transform.yield
87+   }
88+ }
89+
90+ // -----
91+
92+ ///----------------------------------------------------------------------------------------
93+ /// Tests for vector.transfer_read + vector.transfer_write pairs
94+ ///
95+ /// * Nested in double loops
96+ /// * Indices are constant
97+ ///----------------------------------------------------------------------------------------
98+
99+ // CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write
100+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
101+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
102+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
103+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
104+ func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
19105 %c0 = arith.constant 0 : index
20- %cst = arith.constant 0.0 : f32
106+   %pad = arith.constant 0.0 : f32
107+ // Negative test: a use of %mem sits between the transfer_read and the transfer_write inside the inner loop. The CHECK lines expect the pair to stay in place inside both loops.
108+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
109+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
110+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
111+ // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
112+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
113+ // CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
114+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
115+ // CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
116+ // CHECK: }
117+ // CHECK: }
118+   scf.for %i = %lb to %ub step %step {
119+     scf.for %j = %lb to %ub step %step {
120+       %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
121+       %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
122+       "mem_use"(%mem) : (memref<?x?xf32>) -> ()  // use of %mem between the read and the write
123+       vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
124+     }
125+   }
126+   return
127+ }
21128
22- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
23- // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
24- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
25- // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
26- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
27- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
28- // CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
29- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
30- // CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
31- // CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
32- // CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
33- // CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
34- // CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
35- // CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
36- // CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
37- // CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
38- // CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
39- // CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
129+ module attributes {transform.with_named_sequence} {  // transform script for this split
130+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
131+     %0 = transform.structured.match ops{["func.func"]} in %arg1
132+       : (!transform.any_op) -> !transform.any_op  // handle to the matched func.func ops
133+     transform.structured.hoist_redundant_vector_transfers %0
134+       : (!transform.any_op) -> !transform.any_op
135+     transform.yield
136+   }
137+ }
138+
139+ // -----
140+
141+ // CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write
142+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
143+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
144+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
145+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
146+ func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
147+   %c0 = arith.constant 0 : index
148+   %pad = arith.constant 0.0 : f32
149+ // Negative test: a use of %mem follows the transfer_write inside the inner loop. The CHECK lines expect the read/write pair to stay in place inside both loops.
150+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
151+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
152+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
153+ // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
154+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
155+ // CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
156+ // CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
157+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
158+ // CHECK: }
159+ // CHECK: }
160+   scf.for %i = %lb to %ub step %step {
161+     scf.for %j = %lb to %ub step %step {
162+       %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
163+       %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
164+       vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
165+       "mem_use"(%mem) : (memref<?x?xf32>) -> ()  // use of %mem after the write, still in the inner loop
166+     }
167+   }
168+   return
169+ }
170+
171+ module attributes {transform.with_named_sequence} {  // transform script for this split
172+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
173+     %0 = transform.structured.match ops{["func.func"]} in %arg1
174+       : (!transform.any_op) -> !transform.any_op  // handle to the matched func.func ops
175+     transform.structured.hoist_redundant_vector_transfers %0
176+       : (!transform.any_op) -> !transform.any_op
177+     transform.yield
178+   }
179+ }
180+
181+ // -----
182+
183+ // CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read
184+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
185+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
186+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
187+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
188+ func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
189+   %c0 = arith.constant 0 : index
190+   %pad = arith.constant 0.0 : f32
191+ // Negative test: a use of %mem precedes the transfer_read inside the inner loop. The CHECK lines expect the use, the read, and the write to all remain inside both loops.
192+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
193+ // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
194+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
195+ // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
196+ // CHECK: "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
197+ // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
40198// CHECK: }
41- // CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
42- // CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
43- // CHECK: scf.yield {{.*}} : vector<1xf32>
44199// CHECK: }
45- // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
46- // CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
47200 scf.for %i = %lb to %ub step %step {
48201 scf.for %j = %lb to %ub step %step {
49- %r0 = vector.transfer_read %memref1 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <1 xf32 >
50- %r1 = vector.transfer_read %memref0 [%i , %i ], %cst: memref <?x?xf32 >, vector <2 xf32 >
51- %r2 = vector.transfer_read %memref2 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <3 xf32 >
52- %r3 = vector.transfer_read %memref3 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <4 xf32 >
53- " some_crippling_use" (%memref4 ) : (memref <?x?xf32 >) -> ()
54- %r4 = vector.transfer_read %memref4 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <5 xf32 >
55- %r5 = vector.transfer_read %memref5 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <6 xf32 >
56- " some_crippling_use" (%memref5 ) : (memref <?x?xf32 >) -> ()
57- %u0 = " some_use" (%r0 ) : (vector <1 xf32 >) -> vector <1 xf32 >
58- %u1 = " some_use" (%r1 ) : (vector <2 xf32 >) -> vector <2 xf32 >
59- %u2 = " some_use" (%memref2 , %r2 ) : (memref <?x?xf32 >, vector <3 xf32 >) -> vector <3 xf32 >
60- %u3 = " some_use" (%r3 ) : (vector <4 xf32 >) -> vector <4 xf32 >
61- %u4 = " some_use" (%r4 ) : (vector <5 xf32 >) -> vector <5 xf32 >
62- %u5 = " some_use" (%r5 ) : (vector <6 xf32 >) -> vector <6 xf32 >
63- vector.transfer_write %u0 , %memref1 [%c0 , %c0 ] : vector <1 xf32 >, memref <?x?xf32 >
64- vector.transfer_write %u1 , %memref0 [%i , %i ] : vector <2 xf32 >, memref <?x?xf32 >
65- vector.transfer_write %u2 , %memref2 [%c0 , %c0 ] : vector <3 xf32 >, memref <?x?xf32 >
66- vector.transfer_write %u3 , %memref3 [%c0 , %c0 ] : vector <4 xf32 >, memref <?x?xf32 >
67- vector.transfer_write %u4 , %memref4 [%c0 , %c0 ] : vector <5 xf32 >, memref <?x?xf32 >
68- vector.transfer_write %u5 , %memref5 [%c0 , %c0 ] : vector <6 xf32 >, memref <?x?xf32 >
69- " some_crippling_use" (%memref3 ) : (memref <?x?xf32 >) -> ()
202+       "mem_use"(%mem) : (memref<?x?xf32>) -> ()  // use of %mem ahead of the read
203+       %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
204+       %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
205+       vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
70206 }
71- " unrelated_use" (%memref0 ) : (memref <?x?xf32 >) -> ()
72207 }
73- " unrelated_use" (%memref1 ) : (memref <?x?xf32 >) -> ()
74208 return
75209}
76210
@@ -86,6 +220,12 @@ module attributes {transform.with_named_sequence} {
86220
87221// -----
88222
223+ ///----------------------------------------------------------------------------------------
224+ /// Other tests
225+ ///
226+ /// TODO: Document
227+ ///----------------------------------------------------------------------------------------
228+
89229// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
90230// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
91231// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
0 commit comments