@@ -124,6 +124,64 @@ func.func @invariant_affine_if() {
124124
125125// -----
126126
127+ func.func @hoist_invariant_affine_if_success (%lb: index , %ub: index , %step: index ) -> i32 {
128+ %cst_0 = arith.constant 0 : i32
129+ %cst_42 = arith.constant 42 : i32
130+ %sum_result = affine.for %i = %lb to %ub iter_args (%acc = %cst_0 ) -> i32 {
131+ %conditional_add = affine.if affine_set <() : ()> () -> (i32 ) {
132+ %add = arith.addi %cst_42 , %cst_42 : i32
133+ affine.yield %add : i32
134+ } else {
135+ %poison = ub.poison : i32
136+ affine.yield %poison : i32
137+ }
138+ %sum = arith.addi %acc , %conditional_add : i32
139+ affine.yield %sum : i32
140+ }
141+
142+ // CHECK-LABEL: hoist_invariant_affine_if_success
143+ // CHECK-NEXT: arith.constant 0 : i32
144+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
145+ // CHECK-NEXT: %[[IF:.*]] = affine.if
146+ // CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32
147+ // CHECK: affine.for
148+ // CHECK-NOT: affine.if
149+ // CHECK-NEXT: arith.addi %{{.*}}, %[[IF]]
150+
151+ return %sum_result : i32
152+ }
153+
154+ // -----
155+
156+ func.func @hoist_variant_affine_if_failure (%lb: index , %ub: index , %step: index ) -> i32 {
157+ %cst_0 = arith.constant 0 : i32
158+ %cst_42 = arith.constant 42 : i32
159+ %ind_7 = arith.constant 7 : index
160+ %sum_result = affine.for %i = %lb to %ub iter_args (%acc = %cst_0 ) -> i32 {
161+ %conditional_add = affine.if affine_set <(d0 , d1 ) : (d1 - d0 >= 0 )> (%i , %ind_7 ) -> (i32 ) {
162+ %add = arith.addi %cst_42 , %cst_42 : i32
163+ affine.yield %add : i32
164+ } else {
165+ %poison = ub.poison : i32
166+ affine.yield %poison : i32
167+ }
168+ %sum = arith.addi %acc , %conditional_add : i32
169+ affine.yield %sum : i32
170+ }
171+
172+ // CHECK-LABEL: hoist_variant_affine_if_failure
173+ // CHECK-NEXT: arith.constant 0 : i32
174+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
175+ // CHECK-NEXT: arith.constant 7 : index
176+ // CHECK-NEXT: affine.for
177+ // CHECK-NEXT: %[[IF:.*]] = affine.if
178+ // CHECK: arith.addi %{{.*}}, %[[IF]]
179+
180+ return %sum_result : i32
181+ }
182+
183+ // -----
184+
127185func.func @hoist_affine_for_with_unknown_trip_count (%lb: index , %ub: index ) {
128186 affine.for %arg0 = 0 to 10 {
129187 affine.for %arg1 = %lb to %ub {
@@ -383,6 +441,69 @@ func.func @parallel_loop_with_invariant() {
383441
384442// -----
385443
444+ func.func @hoist_invariant_scf_if_success (%lb: index , %ub: index , %step: index ) -> i32 {
445+ %cst_0 = arith.constant 0 : i32
446+ %cst_42 = arith.constant 42 : i32
447+ %true = arith.constant true
448+ %sum_result = scf.for %i = %lb to %ub step %step iter_args (%acc = %cst_0 ) -> i32 {
449+ %conditional_add = scf.if %true -> (i32 ) {
450+ %add = arith.addi %cst_42 , %cst_42 : i32
451+ scf.yield %add : i32
452+ } else {
453+ %poison = ub.poison : i32
454+ scf.yield %poison : i32
455+ }
456+ %sum = arith.addi %acc , %conditional_add : i32
457+ scf.yield %sum : i32
458+ }
459+
460+ // CHECK-LABEL: hoist_invariant_scf_if_success
461+ // CHECK-NEXT: arith.constant 0 : i32
462+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
463+ // CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
464+ // CHECK-NEXT: %[[IF:.*]] = scf.if %[[TRUE]]
465+ // CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32
466+ // CHECK: scf.for
467+ // CHECK-NOT: scf.if
468+ // CHECK-NEXT: arith.addi %{{.*}}, %[[IF]]
469+
470+ return %sum_result : i32
471+ }
472+
473+ // -----
474+
475+ func.func @hoist_variant_scf_if_failure (%lb: index , %ub: index , %step: index ) -> i32 {
476+ %cst_0 = arith.constant 0 : i32
477+ %cst_42 = arith.constant 42 : i32
478+ %ind_7 = arith.constant 7 : index
479+ %sum_result = scf.for %i = %lb to %ub step %step iter_args (%acc = %cst_0 ) -> i32 {
480+ %cond = arith.cmpi ult , %i , %ind_7 : index
481+ %conditional_add = scf.if %cond -> (i32 ) {
482+ %add = arith.addi %cst_42 , %cst_42 : i32
483+ scf.yield %add : i32
484+ } else {
485+ %poison = ub.poison : i32
486+ scf.yield %poison : i32
487+ }
488+ %sum = arith.addi %acc , %conditional_add : i32
489+ scf.yield %sum : i32
490+ }
491+
492+ // CHECK-LABEL: hoist_variant_scf_if_failure
493+ // CHECK-NEXT: arith.constant 0 : i32
494+ // CHECK-NEXT: %[[CST_42:.*]] = arith.constant 42 : i32
495+ // CHECK-NEXT: %[[CST_7:.*]] = arith.constant 7 : index
496+ // CHECK-NEXT: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}}
497+ // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[IV]], %[[CST_7]]
498+ // CHECK-NEXT: %[[IF:.*]] = scf.if %[[CMP]]
499+ // CHECK-NEXT: arith.addi %[[CST_42]], %[[CST_42]] : i32
500+ // CHECK: arith.addi %{{.*}}, %[[IF]]
501+
502+ return %sum_result : i32
503+ }
504+
505+ // -----
506+
386507func.func private @make_val () -> (index )
387508
388509// CHECK-LABEL: func @nested_uses_inside
0 commit comments