@@ -17,9 +17,6 @@ def convert_to_submodules_graph(
1717 chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818 """
1919 gm = copy .deepcopy (gm )
20- num_placeholders = len (
21- [node for node in gm .graph .nodes if node .op == "placeholder" ]
22- )
2320 submodules_body_nodes = [
2421 node
2522 for node in gm .graph .nodes
@@ -207,6 +204,32 @@ def _get_submodule_inputs_and_outputs(
207204 start_node_idx : int ,
208205 end_node_idx : int ,
209206 chain_style = False ,
207+ ):
208+ if not chain_style :
209+ (
210+ minimal_input_nodes ,
211+ minimal_output_nodes ,
212+ ) = _get_minimal_submodule_inputs_and_outputs (
213+ gm = gm , start_node_idx = start_node_idx , end_node_idx = end_node_idx
214+ )
215+ return minimal_input_nodes , minimal_output_nodes , []
216+ else :
217+ node_list = list (gm .graph .nodes )
218+ input_nodes , _ = _get_minimal_submodule_inputs_and_outputs (
219+ gm = gm , start_node_idx = start_node_idx , end_node_idx = len (node_list )
220+ )
221+ output_nodes , _ = _get_minimal_submodule_inputs_and_outputs (
222+ gm = gm , start_node_idx = end_node_idx , end_node_idx = len (node_list )
223+ )
224+ identity_nodes_set = set (input_nodes ) & set (output_nodes )
225+ identity_nodes = [node for node in input_nodes if node in identity_nodes_set ]
226+ return input_nodes , output_nodes , identity_nodes
227+
228+
229+ def _get_minimal_submodule_inputs_and_outputs (
230+ gm : torch .fx .GraphModule ,
231+ start_node_idx : int ,
232+ end_node_idx : int ,
210233):
211234 count_ctx = NodeProducedOrConsumedCountCtx (
212235 defaultdict (int ),
@@ -215,33 +238,34 @@ def _get_submodule_inputs_and_outputs(
215238 )
216239 node_list = list (gm .graph .nodes )
217240
218- def _hashable (obj ):
219- if isinstance (obj , slice ):
220- return ("__slice__" , obj .start , obj .stop , obj .step )
221- elif isinstance (obj , (list , tuple )):
222- return tuple (_hashable (x ) for x in obj )
241+ def get_args_node (arg ):
242+ if isinstance (arg , torch .fx .Node ):
243+ yield arg
244+ elif isinstance (arg , (tuple , list )):
245+ for x in arg :
246+ yield from get_args_node (x )
247+ elif isinstance (arg , slice ):
248+ yield arg .start
249+ yield arg .stop
250+ yield arg .step
223251 else :
224- return obj
252+ assert isinstance ( arg , ( int , bool , float , str , type ( None ))), f" { type ( arg ) = } "
225253
226- def get_related_node (node ):
254+ def get_args_node_and_self_node (node ):
227255 for arg in node .args :
228- if isinstance (arg , tuple ):
229- for x in arg :
230- yield _hashable (x )
231- else :
232- yield _hashable (arg )
233- yield _hashable (node )
256+ yield from get_args_node (arg )
257+ yield node
234258
235259 for node in node_list [0 :start_node_idx ]:
236- for related_node in get_related_node (node ):
260+ for related_node in get_args_node_and_self_node (node ):
237261 count_ctx .node2before_input [related_node ] += 1
238262
239263 for node in node_list [start_node_idx :end_node_idx ]:
240- for related_node in get_related_node (node ):
264+ for related_node in get_args_node_and_self_node (node ):
241265 count_ctx .node2body [related_node ] += 1
242266
243267 for node in node_list [end_node_idx :]:
244- for related_node in get_related_node (node ):
268+ for related_node in get_args_node_and_self_node (node ):
245269 count_ctx .node2after_output [related_node ] += 1
246270
247271 input_nodes = [
@@ -257,24 +281,4 @@ def get_related_node(node):
257281 if count_ctx .node2body [node ] > 0
258282 if count_ctx .node2after_output [node ] > 0
259283 ]
260- if not chain_style :
261- identity_nodes = []
262- else :
263- identity_nodes = [
264- node
265- for node in node_list
266- if count_ctx .node2before_input [node ] > 0
267- if count_ctx .node2body [node ] == 0
268- if count_ctx .node2after_output [node ] > 0
269- ]
270- input_nodes_set = set (input_nodes )
271- input_nodes = [
272- * input_nodes ,
273- * [node for node in identity_nodes if node not in input_nodes_set ],
274- ]
275- output_nodes_set = set (output_nodes )
276- output_nodes = [
277- * output_nodes ,
278- * [node for node in identity_nodes if node not in output_nodes_set ],
279- ]
280- return input_nodes , output_nodes , identity_nodes
284+ return input_nodes , output_nodes
0 commit comments