@@ -70,7 +70,9 @@ def resume(self, rel_model_path: str):
7070 split_positions = self ._get_split_positions (subgraph_ranges )
7171 rewrited_gm : torch .fx .GraphModule = convert_to_submodules_graph (
7272 gm ,
73- submodule_hook = self .get_naive_decomposer_extractor (rel_model_path ),
73+ submodule_hook = self .get_naive_decomposer_extractor (
74+ rel_model_path , subgraph_ranges
75+ ),
7476 split_positions = split_positions ,
7577 subgraph_ranges = subgraph_ranges ,
7678 group_head_and_tail = self .config .get ("group_head_and_tail" , False ),
@@ -90,13 +92,15 @@ def _get_split_positions(self, subgraph_ranges: list[(int, int)]):
9092 split_positions = [position for pair in subgraph_ranges for position in pair ]
9193 return sorted (set (split_positions ))
9294
93- def get_naive_decomposer_extractor (self , rel_model_path ):
95+ def get_naive_decomposer_extractor (self , rel_model_path , subgraph_ranges ):
9496 def fn (submodule , seq_no ):
9597 return NaiveDecomposerExtractorModule (
9698 config = self .config ,
9799 parent_graph_rel_model_path = rel_model_path ,
98100 submodule = submodule ,
99101 seq_no = seq_no ,
102+ subgraph_start = subgraph_ranges [seq_no ][0 ],
103+ subgraph_end = subgraph_ranges [seq_no ][1 ],
100104 )
101105
102106 return fn
@@ -109,6 +113,8 @@ def __init__(
109113 parent_graph_rel_model_path : str ,
110114 submodule : torch .nn .Module ,
111115 seq_no : int ,
116+ subgraph_start : int ,
117+ subgraph_end : int ,
112118 ):
113119 super ().__init__ ()
114120 self .config = config
@@ -120,7 +126,7 @@ def __init__(
120126 if self .seq_no is None :
121127 self .model_name = parent_graph_model_name
122128 else :
123- submodule_name = f"{ parent_graph_model_name } _{ self .seq_no } "
129+ submodule_name = f"{ parent_graph_model_name } _start { subgraph_start } _end { subgraph_end } _{ self .seq_no } "
124130 self .model_name = submodule_name
125131 self .builtin_extractor = BuiltinGraphExtractor (
126132 name = submodule_name ,
0 commit comments