@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
2424
2525// -----
2626
// Test input: two scf.parallel loops with identical bounds and steps,
// separated by an arith.addf that neither reads from nor feeds into
// either loop. The expected output that follows this function matches the
// addf first, then exactly one scf.parallel.
27+ func.func @fuse_ops_between (%A: f32 , %B: f32 ) -> f32 {
28+ %c2 = arith.constant 2 : index
29+ %c0 = arith.constant 0 : index
30+ %c1 = arith.constant 1 : index
// First loop: 2-D iteration space from (%c0, %c0) to (%c2, %c2), step
// (%c1, %c1); the body holds only the scf.reduce terminator.
31+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
32+ scf.reduce
33+ }
// Operation placed between the loops; it uses only the function arguments.
34+ %res = arith.addf %A , %B : f32
// Second loop: same bounds and steps as the first, body also empty.
35+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
36+ scf.reduce
37+ }
38+ return %res : f32
39+ }
40+ // CHECK-LABEL: func @fuse_ops_between
41+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
42+ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
43+ // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
44+ // CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
45+ // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
46+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
47+ // CHECK: scf.reduce
48+ // CHECK: }
49+ // CHECK-NOT: scf.parallel
50+
51+ // -----
52+
2753func.func @fuse_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
2854 %c2 = arith.constant 2 : index
2955 %c0 = arith.constant 0 : index
@@ -89,7 +115,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
89115 memref.store %product_elem , %prod [%i , %j ] : memref <2 x2 xf32 >
90116 scf.reduce
91117 }
92- scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
118+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
93119 %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
94120 %res_elem = arith.addf %A_elem , %c2fp : f32
95121 memref.store %res_elem , %B [%i , %j ] : memref <2 x2 xf32 >
@@ -575,3 +601,215 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
575601// CHECK-NEXT: }
576602// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
577603// CHECK-NEXT: return
604+
605+ // -----
606+
// Test input: two scf.parallel reduction loops over the same iteration
// space, each producing one f32 result. Neither loop uses the other's
// result. The expected output that follows matches a single scf.parallel
// with two results, both init values, and one scf.reduce carrying two
// operands and two reduction regions (addf first, then mulf).
607+ func.func @fuse_reductions_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
608+ %c2 = arith.constant 2 : index
609+ %c0 = arith.constant 0 : index
610+ %c1 = arith.constant 1 : index
611+ %init1 = arith.constant 1.0 : f32
612+ %init2 = arith.constant 2.0 : f32
// First reduction: loads elements of %A and combines them with arith.addf,
// starting from %init1.
613+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
614+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
615+ scf.reduce (%A_elem : f32 ) {
616+ ^bb0 (%lhs: f32 , %rhs: f32 ):
617+ %1 = arith.addf %lhs , %rhs : f32
618+ scf.reduce.return %1 : f32
619+ }
620+ }
// Second reduction: loads elements of %B and combines them with arith.mulf,
// starting from %init2.
621+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
622+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
623+ scf.reduce (%B_elem : f32 ) {
624+ ^bb0 (%lhs: f32 , %rhs: f32 ):
625+ %1 = arith.mulf %lhs , %rhs : f32
626+ scf.reduce.return %1 : f32
627+ }
628+ }
629+ return %res1 , %res2 : f32 , f32
630+ }
631+
632+ // CHECK-LABEL: func @fuse_reductions_two
633+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
634+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
635+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
636+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
637+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
638+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
639+ // CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
640+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
641+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
642+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
643+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
644+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
645+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
646+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
647+ // CHECK: scf.reduce.return %[[R]] : f32
648+ // CHECK: }
649+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
650+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
651+ // CHECK: scf.reduce.return %[[R]] : f32
652+ // CHECK: }
653+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
654+
655+ // -----
656+
// Test input: three scf.parallel reduction loops over the same iteration
// space, each producing one f32 result and none using another loop's
// result. The expected output that follows matches a single scf.parallel
// with three results, three init values, and one scf.reduce carrying three
// operands and three reduction regions (addf, mulf, addf).
657+ func.func @fuse_reductions_three (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >, %C: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
658+ %c2 = arith.constant 2 : index
659+ %c0 = arith.constant 0 : index
660+ %c1 = arith.constant 1 : index
661+ %init1 = arith.constant 1.0 : f32
662+ %init2 = arith.constant 2.0 : f32
663+ %init3 = arith.constant 3.0 : f32
// First reduction: sum over elements of %A, starting from %init1.
664+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
665+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
666+ scf.reduce (%A_elem : f32 ) {
667+ ^bb0 (%lhs: f32 , %rhs: f32 ):
668+ %1 = arith.addf %lhs , %rhs : f32
669+ scf.reduce.return %1 : f32
670+ }
671+ }
// Second reduction: product over elements of %B, starting from %init2.
672+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
673+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
674+ scf.reduce (%B_elem : f32 ) {
675+ ^bb0 (%lhs: f32 , %rhs: f32 ):
676+ %1 = arith.mulf %lhs , %rhs : f32
677+ scf.reduce.return %1 : f32
678+ }
679+ }
// Third reduction: sum over elements of %C, starting from %init3. The loaded
// value is named after the memref it reads from, like the two loops above.
680+ %res3 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init3 ) -> f32 {
681+ %C_elem = memref.load %C [%i , %j ] : memref <2 x2 xf32 >
682+ scf.reduce (%C_elem : f32 ) {
683+ ^bb0 (%lhs: f32 , %rhs: f32 ):
684+ %1 = arith.addf %lhs , %rhs : f32
685+ scf.reduce.return %1 : f32
686+ }
687+ }
688+ return %res1 , %res2 , %res3 : f32 , f32 , f32
689+ }
690+
691+ // CHECK-LABEL: func @fuse_reductions_three
692+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
693+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
694+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
695+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
696+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
697+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
698+ // CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
699+ // CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
700+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
701+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
702+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
703+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
704+ // CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
705+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
706+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
707+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
708+ // CHECK: scf.reduce.return %[[R]] : f32
709+ // CHECK: }
710+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
711+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
712+ // CHECK: scf.reduce.return %[[R]] : f32
713+ // CHECK: }
714+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
715+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
716+ // CHECK: scf.reduce.return %[[R]] : f32
717+ // CHECK: }
718+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
719+
720+ // -----
721+
// Negative test input: two scf.parallel reduction loops over the same
// iteration space, where the second loop takes the first loop's result as
// its init value. The expected output that follows still matches two
// separate scf.parallel ops.
722+ func.func @reductions_use_res (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
723+ %c2 = arith.constant 2 : index
724+ %c0 = arith.constant 0 : index
725+ %c1 = arith.constant 1 : index
726+ %init1 = arith.constant 1.0 : f32
// First reduction: sum over elements of %A, starting from %init1.
727+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
728+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
729+ scf.reduce (%A_elem : f32 ) {
730+ ^bb0 (%lhs: f32 , %rhs: f32 ):
731+ %1 = arith.addf %lhs , %rhs : f32
732+ scf.reduce.return %1 : f32
733+ }
734+ }
// Second reduction: product over elements of %B. Its init operand is %res1,
// which makes this loop depend on the first loop's result.
735+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%res1 ) -> f32 {
736+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
737+ scf.reduce (%B_elem : f32 ) {
738+ ^bb0 (%lhs: f32 , %rhs: f32 ):
739+ %1 = arith.mulf %lhs , %rhs : f32
740+ scf.reduce.return %1 : f32
741+ }
742+ }
743+ return %res1 , %res2 : f32 , f32
744+ }
745+
746+ // %res1 is used as the init operand of the second scf.parallel, so the loops cannot be fused
747+ // CHECK-LABEL: func @reductions_use_res
748+ // CHECK: scf.parallel
749+ // CHECK: scf.parallel
750+
751+ // -----
752+
// Negative test input: two scf.parallel reduction loops over the same
// iteration space, where the body of the second loop reads the first loop's
// result. The expected output that follows still matches two separate
// scf.parallel ops.
753+ func.func @reductions_use_res_inside (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
754+ %c2 = arith.constant 2 : index
755+ %c0 = arith.constant 0 : index
756+ %c1 = arith.constant 1 : index
757+ %init1 = arith.constant 1.0 : f32
758+ %init2 = arith.constant 2.0 : f32
// First reduction: sum over elements of %A, starting from %init1.
759+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
760+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
761+ scf.reduce (%A_elem : f32 ) {
762+ ^bb0 (%lhs: f32 , %rhs: f32 ):
763+ %1 = arith.addf %lhs , %rhs : f32
764+ scf.reduce.return %1 : f32
765+ }
766+ }
// Second reduction: product, starting from %init2. Each reduced value is
// the loaded %B element plus %res1, so the loop body depends on the first
// loop's result.
767+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
768+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
769+ %sum = arith.addf %B_elem , %res1 : f32
770+ scf.reduce (%sum : f32 ) {
771+ ^bb0 (%lhs: f32 , %rhs: f32 ):
772+ %1 = arith.mulf %lhs , %rhs : f32
773+ scf.reduce.return %1 : f32
774+ }
775+ }
776+ return %res1 , %res2 : f32 , f32
777+ }
778+
779+ // %res1 is used inside second scf.parallel, cannot fuse
780+ // CHECK-LABEL: func @reductions_use_res_inside
781+ // CHECK: scf.parallel
782+ // CHECK: scf.parallel
783+
784+ // -----
785+
// Negative test input: two scf.parallel reduction loops over the same
// iteration space, with an arith.addf between them that consumes the first
// loop's result. The second loop itself does not use %res1. The expected
// output that follows still matches two separate scf.parallel ops.
786+ func.func @reductions_use_res_between (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
787+ %c2 = arith.constant 2 : index
788+ %c0 = arith.constant 0 : index
789+ %c1 = arith.constant 1 : index
790+ %init1 = arith.constant 1.0 : f32
791+ %init2 = arith.constant 2.0 : f32
// First reduction: sum over elements of %A, starting from %init1.
792+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
793+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
794+ scf.reduce (%A_elem : f32 ) {
795+ ^bb0 (%lhs: f32 , %rhs: f32 ):
796+ %1 = arith.addf %lhs , %rhs : f32
797+ scf.reduce.return %1 : f32
798+ }
799+ }
// Operation between the two loops; it reads %res1 from the first loop.
800+ %res3 = arith.addf %res1 , %init2 : f32
// Second reduction: product over elements of %B, starting from %init2.
801+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
802+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
803+ scf.reduce (%B_elem : f32 ) {
804+ ^bb0 (%lhs: f32 , %rhs: f32 ):
805+ %1 = arith.mulf %lhs , %rhs : f32
806+ scf.reduce.return %1 : f32
807+ }
808+ }
809+ return %res1 , %res2 , %res3 : f32 , f32 , f32
810+ }
811+
812+ // The operation between the loops uses the first loop's result, so the loops must not be fused
813+ // CHECK-LABEL: func @reductions_use_res_between
814+ // CHECK: scf.parallel
815+ // CHECK: scf.parallel
0 commit comments