@@ -24,269 +24,11 @@ def encode_config(config: Dict[str, Any]) -> str:
     return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
 
 
-class GraphExtractor:
-    def __init__(self):
-        self.extract_node = []
-
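-    # Walk the traced FX graph and record every call node; call_module
-    # targets are mapped to their module class names for readability.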
-    def _extract_operators_from_graph(
-        self, gm: nn.Module, example_inputs: List[torch.Tensor] = None
-    ) -> List[Dict[str, Any]]:
-        operator_list = []
-        named_modules = dict(gm.named_modules())
-
-        for node in gm.graph.nodes:
-            if node.op in ("call_method", "call_function", "call_module"):
-                target_name = str(node.target)
-
-                if node.op == "call_module":
-                    module_instance = named_modules.get(node.target)
-                    if module_instance is not None:
-                        target_name = type(module_instance).__name__
-                elif node.op == "call_function":
-                    if isinstance(node.target, Callable):
-                        target_name = node.target.__name__
-                elif node.op == "call_method":
-                    target_name = str(node.target)
-
-                operator_info = {
-                    "op_type": node.op,
-                    "target": node.target,
-                    "name": node.name,
-                    "target_name": target_name,
-                }
-                operator_list.append(operator_info)
-
-        return operator_list
-
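-    # Custom torch.compile backend: record the captured graph's operators
-    # as a side effect, then return gm.forward so execution stays eager.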
-    def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
-        operators = self._extract_operators_from_graph(gm, inputs)
-        self.extract_node = operators
-        return gm.forward
-
-
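-# Loads a standalone model.py from disk and rebuilds the recorded input
-# tensors needed to run it.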
-class ModelLoader:
-    def load_class_from_file(self, model_path: str, device: str) -> Any:
-        file_path = os.path.join(model_path, "model.py")
-        file = Path(file_path).resolve()
-        module_name = file.stem
-
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"Model file not found: {file_path}")
-
-        with open(file_path, "r", encoding="utf-8") as f:
-            model_code = f.read()
-
-        model_code = graph_utils.modify_code_by_device(model_code, device)
-
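-        # Materialize the source as a fresh in-memory module: register an
-        # empty module, then exec the device-rewritten code into its
-        # namespace (an import without a file loader).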
-        spec = importlib.util.spec_from_loader(module_name, loader=None)
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[module_name] = module
-
-        compiled_code = compile(model_code, filename=file, mode="exec")
-        exec(compiled_code, module.__dict__)
-
-        model_class = getattr(module, "GraphModule", None)
-        if model_class is None:
-            raise ImportError(f"Class 'GraphModule' not found in {file_path}")
-
-        return model_class
-
-    def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]:
-        inputs_params = graph_utils.load_converted_from_text(f"{model_path}")
-        params = inputs_params["weight_info"]
-        for tensor_meta in params.values():
-            if hasattr(tensor_meta, "device"):
-                tensor_meta.device = device
-        input_dict = {
-            k: graph_utils.replay_tensor(v).to(torch.device(device))
-            for k, v in params.items()
-        }
-        return input_dict
-
-
 class RangeDecomposerBackend:
     def __init__(self):
-        self.window_size = 10
         self.graph_net_root = Path(graph_net.__file__).parent
         self.workspace_root = Path.cwd() / "naive_decompose_workspace"
 
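-    # Recursively expand a token id into primitive op names: ids below
-    # num_primitives are primitives; symbol ids expand via symbol_map.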
-    def _resolve_token_to_ops(
-        self, tid, num_primitives, token_id2primitive_id, symbol_map
-    ) -> List[str]:
-        if tid < num_primitives:
-            return [token_id2primitive_id[tid]]
-        if tid in symbol_map:
-            sub_tokens = symbol_map[tid].tolist()
-            ops = []
-            for t in sub_tokens:
-                ops.extend(
-                    self._resolve_token_to_ops(
-                        t, num_primitives, token_id2primitive_id, symbol_map
-                    )
-                )
-            return ops
-        return [f"Unknown({tid})"]
-
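-    # Compile the model once with the extraction backend to obtain its
-    # operator sequence; returns [] if loading or extraction fails.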
-    def _extract_ops_via_compile(
-        self, model_path: str, device: str = "cpu"
-    ) -> List[str]:
-        loader = ModelLoader()
-        print(f"Loading model from {model_path} on {device}...")
-        try:
-            model_class = loader.load_class_from_file(model_path, device)
-            model = model_class().to(torch.device(device))
-            model.eval()
-            input_dict = loader.get_input_dict(model_path, device)
-        except Exception as e:
-            print(f"Error loading/preparing model {model_path}: {e}")
-            return []
-
-        extractor = GraphExtractor()
-        compiled_model = torch.compile(model, backend=extractor.extract_compiler)
-
-        with torch.no_grad():
-            compiled_model(**input_dict)
-
-        ops_info = extractor.extract_node
-        if not ops_info:
-            print(f"Warning: No operators extracted from {model_path}.")
-            return []
-        return [op["target_name"] for op in ops_info]
-
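-    # Memoized expansion length per token: primitives count as 1, symbols
-    # as the sum of their sub-token lengths.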
-    def _calculate_token_lengths(
-        self, rp_expr, num_primitives, symbol_map
-    ) -> Dict[int, int]:
-        token2len = {}
-
-        def get_len(tid):
-            if tid in token2len:
-                return token2len[tid]
-            if tid < num_primitives:
-                token2len[tid] = 1
-                return 1
-            if tid in symbol_map:
-                sub_tokens = symbol_map[tid].tolist()
-                length = sum(get_len(t) for t in sub_tokens)
-                token2len[tid] = length
-                return length
-            token2len[tid] = 1
-            return 1
-
-        for sym_id in rp_expr.symbol_token_ids:
-            get_len(sym_id)
-        return token2len
-
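-    # Mine repeated op patterns across all model sequences, then derive
-    # per-model split points at the boundaries of each repeated pattern.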
-    def _analyze_and_get_splits(self, args) -> Dict[str, Dict]:
-        input_file = Path(args.model_path)
-        if not input_file.exists():
-            print(f"Error: Input file {input_file} does not exist.")
-            return {}
-
-        with open(input_file, "r") as f:
-            model_paths = [
-                Path(line.strip())
-                for line in f
-                if line.strip() and not line.startswith("#")
-            ]
-
-        if not model_paths:
-            print("No valid model paths found.")
-            return {}
-
-        inputs_seqs = []
-        valid_models = []
-
-        for p in model_paths:
-            seq = self._extract_ops_via_compile(p, args.device)
-            if seq:
-                inputs_seqs.append(seq)
-                valid_models.append((p.name, p))
-
-        if not inputs_seqs:
-            return {}
-
-        rp_parser = RpExprParser(
-            window_size=self.window_size, fold_policy="default", fold_times=0
-        )
-        rp_expr, token_id2primitive_id = rp_parser(inputs_seqs)
-        rp_expr.try_unwrap_body_of_sole_symbol_token()
-        rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id)
-        num_primitives = len(token_id2primitive_id)
-        symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))
-        token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map)
-
-        results = {}
-
-        for i, (model_name, original_path) in enumerate(valid_models):
-            if i >= len(rp_expr.body_rp_expr):
-                break
-
-            target_body_tensor = rp_expr.body_rp_expr[i]
-            seq_tokens = target_body_tensor.tolist()
-
-            full_model_ops = []
-            for t in seq_tokens:
-                full_model_ops.extend(
-                    self._resolve_token_to_ops(
-                        t, num_primitives, token_id2primitive_id, symbol_map
-                    )
-                )
-
-            current_idx = 0
-            split_points_set = set()
-            total_len = sum(token2len.get(t, 1) for t in seq_tokens)
-
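-            # Each pattern token contributes its start and end offsets as
-            # split points, excluding the sequence's outer boundaries.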
-            for token_id in seq_tokens:
-                length = token2len.get(token_id, 1)
-                is_pattern = token_id >= num_primitives
-
-                if is_pattern:
-                    if current_idx > 0:
-                        split_points_set.add(current_idx)
-                    end_idx = current_idx + length
-                    if end_idx < total_len:
-                        split_points_set.add(end_idx)
-
-                current_idx += length
-
-            sorted_splits = sorted(split_points_set)
-
-            self._print_analysis(
-                model_name, original_path, sorted_splits, total_len, full_model_ops
-            )
-
-            results[model_name] = {
-                "path": str(original_path),
-                "split_points": sorted_splits,
-            }
-
-        return results
-
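-    # Report each segment's index range, length, and an abbreviated op list.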
-    def _print_analysis(self, name, path, splits, total_len, full_ops):
-        print("=" * 60)
-        print(f"Model: {name}")
-        print(f"Path: {path}")
-        print(f"Splits: {splits}")
-        print("-" * 60)
-
-        last_split = 0
-        for split in splits + [total_len]:
-            segment_len = split - last_split
-
-            start_safe = min(last_split, len(full_ops))
-            end_safe = min(split, len(full_ops))
-            segment_ops = full_ops[start_safe:end_safe]
-
-            ops_display = str(segment_ops)
-            if len(segment_ops) > 5:
-                ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]"
-
-            print(
-                f"  Range [{last_split:3d}, {split:3d}), Len: {segment_len:3d} | Ops: {ops_display}"
-            )
-            last_split = split
-        print("\n")
-
     def __call__(self, args):
         model_data_map = self._analyze_and_get_splits(args)
 