@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
565565 transform.yield
566566 }
567567}
568+
569+ // -----
570+
571+ // Test hoisting of vector.extract/vector.broadcast pairs
572+
573+ // CHECK-LABEL: func.func @hoist_vector_broadcasts
574+ // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
575+ // CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
576+ // CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
577+ // CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
578+ // CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
579+ // CHECK-NEXT: }
580+ // CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
581+ // CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
582+
// Input IR: the loop carries a vector<3x4xf32>, but each iteration only reads
// row 0 of the carried value and rebuilds it by broadcasting a vector<4xf32>.
// The expectations listed before this function place the extract ahead of the
// loop and the broadcast after it, leaving a loop that carries vector<4xf32>.
func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
    %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
    // Unregistered op in generic form; its name must be spelled exactly as in
    // the expectations ("some_use", no surrounding whitespace).
    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
    scf.yield %broadcast : vector<3x4xf32>
  }
  return %bcast_vec : vector<3x4xf32>
}
592+
// Transform script for this test case: match every func.func under the payload
// root and apply hoist_redundant_vector_broadcasts to the matched ops.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // The op name is compared as a string, so it must be exactly "func.func".
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
602+
603+ // -----
604+
605+ // Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
606+
607+ // CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
608+ // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
609+ // CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
610+ // CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
611+ // CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
612+ // CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
613+ // CHECK-NEXT: }
614+ // CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
615+ // CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
616+
// Input IR: same extract / use / broadcast loop shape as the static-position
// case, except the extracted row is selected by %pos, a function argument
// defined outside the loop rather than a constant index.
func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
    // Unregistered op in generic form; its name must be spelled exactly as in
    // the expectations ("some_use", no surrounding whitespace).
    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
    scf.yield %broadcast : vector<3x4xf32>
  }
  return %bcast_vec : vector<3x4xf32>
}
626+
// Transform script for this test case: match every func.func under the payload
// root and apply hoist_redundant_vector_broadcasts to the matched ops.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // The op name is compared as a string, so it must be exactly "func.func".
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
636+
637+ // -----
638+
639+ // Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
640+
641+ // CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
642+ // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
643+ // CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
644+ // CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
645+ // CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
646+ // CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
647+ // CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
648+ // CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
649+ // CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
650+ // CHECK-NEXT: }
651+ // CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
652+ // CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
653+ // CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
654+
// Input IR: a loop with two carried vectors of different shapes (3x4 and 3x5).
// Each one goes through its own extract / use / broadcast chain (row 0 of the
// first, row 1 of the second), and both rebuilt values are yielded together.
func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
  %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
    %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
    %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
    // Unregistered ops in generic form; the names must be spelled exactly as in
    // the expectations ("some_use1" / "some_use2", no surrounding whitespace).
    %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
    %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
    %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
    %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
    scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>, vector<3x5xf32>
  }
  return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
}
667+
// Transform script for this test case: match every func.func under the payload
// root and apply hoist_redundant_vector_broadcasts to the matched ops.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // The op name is compared as a string, so it must be exactly "func.func".
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
0 commit comments