1717 Partitioner ,
1818 PartitionResult ,
1919)
20+ from torch .fx import GraphModule
2021from torch .fx .passes .infra .partitioner import Partition
22+ from executorch .exir .graph_module import get_control_flow_submodules
2123
2224
2325def format_target_name (target_name : str ) -> str :
@@ -160,11 +162,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
160162 return (do_not_decomp , filter_fn )
161163
162164 def get_matched_nodes_from_configs (
163- self , ep : ExportedProgram
165+ self , ep : ExportedProgram , gm : Optional [ GraphModule ] = None
164166 ) -> List [List [torch .fx .Node ]]:
165167 # gather supported nodes
166168 matched_nodes = []
167- gm = ep .graph_module
169+ gm = gm or ep .graph_module
168170 for node in gm .graph .nodes :
169171 if node .op == "call_function" :
170172 target = format_target_name (node .target .__name__ )
@@ -175,17 +177,19 @@ def get_matched_nodes_from_configs(
175177
176178 return matched_nodes
177179
178- def generate_partitions (self , ep : ExportedProgram ) -> List [Partition ]:
179- matched_nodes = self .get_matched_nodes_from_configs (ep )
180+ def generate_partitions (self , ep : ExportedProgram , gm : Optional [GraphModule ] = None ) -> List [Partition ]:
181+ gm = gm or ep .graph_module
182+ matched_nodes = self .get_matched_nodes_from_configs (ep , gm )
180183 # create partitions
181184 partitions = generate_partitions_from_list_of_nodes (
182- ep . graph_module ,
185+ gm ,
183186 matched_nodes ,
184187 )
185188 return partitions
186189
187- def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
188- partitions = self .generate_partitions (exported_program )
190+ def partition (self , exported_program : ExportedProgram , graph_module : Optional [GraphModule ] = None ) -> PartitionResult :
191+ graph_module = graph_module or exported_program .graph_module
192+ partitions = self .generate_partitions (exported_program , graph_module )
189193
190194 # tag nodes
191195 partition_tags : Dict [str , DelegationSpec ] = {}
@@ -199,6 +203,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
199203 node .meta ["delegation_tag" ] = delegation_tag
200204 partition_tags [delegation_tag ] = self .delegation_spec
201205
206+ for _ , submodule , _ in get_control_flow_submodules (graph_module ):
207+ # pyre-ignore
208+ self .partition (exported_program , submodule )
209+
202210 return PartitionResult (
203211 tagged_exported_program = exported_program , partition_tags = partition_tags
204212 )
0 commit comments