@@ -16,11 +16,14 @@ func.func @thread_tile_loop() {
       gpu.barrier
     }
   }
-  // The inner loop doesn't always execute once so it cannot be removed.
-  // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C250]] step %[[C250]]
-  // CHECK: gpu.barrier
   scf.for %arg3 = %tidy to %c2 step %c2 {
+    // CHECK-NOT: scf.for
     %0 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%tidx]
+    // CHECK: %[[LB:.+]] = affine.apply
+    // The inner loop doesn't always execute once, so it needs an scf.if.
+    // CHECK: %[[COND:.+]] = arith.cmpi slt, %[[LB]], %[[C250]] : index
+    // CHECK: scf.if %[[COND]] {
+    // CHECK: gpu.barrier
     scf.for %arg4 = %0 to %c250 step %c250 {
       gpu.barrier
     }
@@ -161,6 +164,7 @@ func.func @delinearize_linearize() {
   %c64 = arith.constant 64 : index
   %tidx = gpu.thread_id x upper_bound 128
   %ids:2 = affine.delinearize_index %tidx into (4, 32) : index, index
+  // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index
   // CHECK-NOT: scf.for
   // CHECK: gpu.barrier
   scf.for %arg3 = %ids#0 to %c4 step %c4 {
@@ -169,8 +173,9 @@ func.func @delinearize_linearize() {
       gpu.barrier
     }
   }
-  // The loop loop doesn't always execute once so it cannot be removed.
-  // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C3]] step %{{.*}}
+  // The loop doesn't always execute once, so it needs an scf.if.
+  // CHECK: %[[COND:.+]] = arith.cmpi slt, %[[IDS]]#0, %[[C3]] : index
+  // CHECK: scf.if %[[COND]] {
   // CHECK: gpu.barrier
   scf.for %arg3 = %ids#0 to %c3 step %c4 {
     gpu.barrier
@@ -220,3 +225,91 @@ func.func @argument_with_assume(%arg_index : index) {
   }
   return
 }
+
+// -----
+
+func.func @dynamic_ub_unittrip(%arg_index : index, %arg_value : memref<8xf16>) {
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 3> : index
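+  // Trip count is at most one (ub <= 3, step 3) but may be zero, so expect an scf.if on 0 < %0.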
+  scf.for %arg1 = %c0 to %0 step %c3 {
+    %alloc = memref.alloc() : memref<4xf16>
+    %subview = memref.subview %arg_value[%arg1] [4] [1] : memref<8xf16> to memref<4xf16, strided<[1], offset: ?>>
+    memref.copy %alloc, %subview : memref<4xf16> to memref<4xf16, strided<[1], offset: ?>>
+  }
+  return
+}
+// CHECK-LABEL: func.func @dynamic_ub_unittrip
+// CHECK-SAME: (%[[ARGINDEX:.+]]: index, %[[ARGVALUE:.+]]: memref<8xf16>)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[UB:.+]] = util.assume.int %[[ARGINDEX]]
+// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]] : index
+// CHECK: scf.if %[[COND]] {
+// CHECK: memref.alloc()
+// CHECK: memref.subview %[[ARGVALUE]][%[[C0]]] [4] [1]
+// CHECK: memref.copy
+
+// -----
+
+func.func @dynamic_lb_unittrip(%arg_index : index, %arg_value : memref<8xf16>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 50> : index
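+  // The dynamic lower bound may already be >= 3, so expect an scf.if on %0 < 3 instead of a loop.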
+  scf.for %arg1 = %0 to %c3 step %c3 {
+    %alloc = memref.alloc() : memref<4xf16>
+    %subview = memref.subview %arg_value[%arg1] [4] [1] : memref<8xf16> to memref<4xf16, strided<[1], offset: ?>>
+    memref.copy %alloc, %subview : memref<4xf16> to memref<4xf16, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @dynamic_lb_unittrip
+// CHECK-SAME: (%[[ARGINDEX:.+]]: index, %[[ARGVALUE:.+]]: memref<8xf16>)
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[LB:.+]] = util.assume.int %[[ARGINDEX]]
+// CHECK: %[[COND:.+]] = arith.cmpi slt, %[[LB]], %[[C3]] : index
+// CHECK: scf.if %[[COND]] {
+// CHECK: memref.alloc()
+// CHECK: memref.subview %[[ARGVALUE]][%[[LB]]] [4] [1]
+// CHECK: memref.copy
+
+// -----
+
+func.func @dynamic_nonunittrip(%arg_index : index, %arg_value : memref<8xf16>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 5> : index
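+  // Up to two iterations are possible (e.g. %0 = 5 visits %arg1 = 1 and 4), so the loop must remain.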
+  scf.for %arg1 = %c1 to %0 step %c3 {
+    gpu.barrier
+  }
+  return
+}
+// CHECK-LABEL: func.func @dynamic_nonunittrip
+// CHECK: scf.for
+
+// -----
+
+func.func @dynamic_unittrip_with_destination(%arg_index : index, %arg_value : tensor<8xf16>) -> tensor<4xf16> {
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 3> : index
+  %empty = tensor.empty() : tensor<4xf16>
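+  // The single-iteration loop carries a result, so expect an scf.if yielding the body value, with %empty yielded in the else branch.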
+  %1 = scf.for %arg1 = %c0 to %0 step %c3 iter_args(%arg2 = %empty) -> (tensor<4xf16>) {
+    %extract = tensor.extract_slice %arg_value[%arg1] [4] [1] : tensor<8xf16> to tensor<4xf16>
+    %2 = arith.negf %extract : tensor<4xf16>
+    scf.yield %2 : tensor<4xf16>
+  }
+  return %1 : tensor<4xf16>
+}
+
+// CHECK-LABEL: func.func @dynamic_unittrip_with_destination
+// CHECK-SAME: (%[[ARGINDEX:.+]]: index, %[[ARGTENSOR:.+]]: tensor<8xf16>)
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4xf16>
+// CHECK: %[[RESULT:.+]] = scf.if
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice
+// CHECK: %[[NEG:.+]] = arith.negf %[[SLICE]] : tensor<4xf16>
+// CHECK: scf.yield %[[NEG]] : tensor<4xf16>
+// CHECK: } else {
+// CHECK: scf.yield %[[EMPTY]] : tensor<4xf16>
+// CHECK: }
+// CHECK: return %[[RESULT]] : tensor<4xf16>