@@ -9,9 +9,13 @@ def convert_to_submodules_graph(
99 original_gm : torch .fx .GraphModule ,
1010 split_positions : list [int ],
1111 submodule_hook = None ,
12- submodule_name_prefix = "extraced_submodule" ,
12+ submodule_name_prefix = "extracted_submodule" ,
13+ chain_style = False ,
1314 group_head_and_tail = True ,
1415):
16+ """
17+ chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
18+ """
1519 original_gm = copy .deepcopy (original_gm )
1620 num_placeholders = len (
1721 [node for node in original_gm .graph .nodes if node .op == "placeholder" ]
@@ -68,6 +72,12 @@ def get_end_node_idx(range_idx):
6872 return i + 1
6973 raise NotImplementedError ("Dead code." )
7074
def print_submodule_call(prompt, gm, name_prefix="extracted_submodule"):
    """Debug helper: print the generated-code lines that call a submodule.

    Args:
        prompt: Label printed in front of the matching statements.
        gm: Object exposing a ``code`` string (e.g. a GraphModule's
            generated source).
        name_prefix: Attribute-name prefix of the submodules to look for.
            Lines containing ``self.<name_prefix>`` are reported. Defaults
            to the prefix that was previously hard-coded here.
    """
    needle = f"self.{name_prefix}"
    # splitlines() keeps each statement intact regardless of how the
    # generated code terminates its lines.
    submodule_call_stmts = [
        stmt for stmt in gm.code.splitlines() if needle in stmt
    ]
    print(prompt, submodule_call_stmts)
80+
7181 for range_idx in range (len (range_idx2submodule_body_nodes )):
7282 (
7383 submodule_input_nodes ,
@@ -76,6 +86,7 @@ def get_end_node_idx(range_idx):
7686 original_gm = original_gm ,
7787 start_node_idx = get_start_node_idx (range_idx ),
7888 end_node_idx = get_end_node_idx (range_idx ),
89+ chain_style = chain_style ,
7990 )
8091
8192 def get_input_nodes (range_idx ):
@@ -136,9 +147,14 @@ def get_output_nodes(range_idx):
136147 # Erase old nodes
137148 for node in reversed (get_body_nodes (range_idx )):
138149 original_gm .graph .erase_node (node )
150+ # print_submodule_call("(fx) after Erase old nodes", original_gm)
151+
152+ # print_submodule_call("(fx) before recompile", original_gm)
139153
140154 original_gm .recompile ()
141155
156+ # print_submodule_call("(fx) after recompile", original_gm)
157+
142158 return original_gm
143159
144160
@@ -147,7 +163,7 @@ def fold_range_to_submodule(
147163 start_node_idx : int ,
148164 end_node_idx : int ,
149165 submodule_hook = None ,
150- submodule_name = "extraced_submodule " ,
166+ submodule_name = "extracted_submodule " ,
151167 group_head_and_tail = True ,
152168):
153169 return convert_to_submodules_graph (
@@ -170,6 +186,7 @@ def _get_submodule_inputs_and_outputs(
170186 original_gm : torch .fx .GraphModule ,
171187 start_node_idx : int ,
172188 end_node_idx : int ,
189+ chain_style = False ,
173190):
174191 count_ctx = NodeProducedOrConsumedCountCtx (
175192 defaultdict (int ),
@@ -179,7 +196,11 @@ def _get_submodule_inputs_and_outputs(
179196 node_list = list (original_gm .graph .nodes )
180197
def get_related_node(node):
    """Yield every value ``node`` refers to positionally, then ``node`` itself.

    Tuple and list arguments are flattened recursively, so values that are
    passed inside a container (for example the operands of a concatenation
    given as a list) are yielded individually instead of the container being
    yielded as one opaque item.

    NOTE(review): only ``node.args`` is traversed; values passed through
    ``node.kwargs`` are not yielded -- confirm whether callers need them.
    """

    def _flatten(values):
        for value in values:
            if isinstance(value, (tuple, list)):
                # Containers may themselves hold containers; descend so
                # every leaf is reported.
                yield from _flatten(value)
            else:
                yield value

    yield from _flatten(node.args)
    yield node
184205
185206 for node in node_list [0 :start_node_idx ]:
@@ -194,19 +215,33 @@ def get_related_node(node):
194215 for related_node in get_related_node (node ):
195216 count_ctx .node2after_output [related_node ] += 1
196217
197- input_nodes = [
198- node
199- for node in node_list
200- if count_ctx .node2before_input [node ] > 0
201- if count_ctx .node2body [node ] > 0
202- ]
203-
204- output_nodes = [
205- node
206- for node in node_list
207- if not (count_ctx .node2before_input [node ] > 0 )
208- if count_ctx .node2body [node ] > 0
209- if count_ctx .node2after_output [node ] > 0
210- ]
218+ if chain_style :
219+ input_nodes = [
220+ node
221+ for node in node_list
222+ if (count_ctx .node2before_input [node ] > 0 )
223+ if (count_ctx .node2body [node ] > 0 or count_ctx .node2after_output [node ] > 0 )
224+ ]
225+ input_nodes_set = set (input_nodes )
226+ output_nodes = [
227+ node
228+ for node in node_list
229+ if (count_ctx .node2before_input [node ] > 0 or count_ctx .node2body [node ] > 0 )
230+ if (count_ctx .node2after_output [node ] > 0 )
231+ ]
232+ else :
233+ input_nodes = [
234+ node
235+ for node in node_list
236+ if count_ctx .node2before_input [node ] > 0
237+ if count_ctx .node2body [node ] > 0
238+ ]
239+ output_nodes = [
240+ node
241+ for node in node_list
242+ if not (count_ctx .node2before_input [node ] > 0 )
243+ if count_ctx .node2body [node ] > 0
244+ if count_ctx .node2after_output [node ] > 0
245+ ]
211246
212247 return input_nodes , output_nodes
0 commit comments