33import os
44from pathlib import Path
55from typing import Any , Dict , List
6-
76import torch
87import torch .nn as nn
9- import tempfile
10- import graph_net .imp_util
11- from graph_net .torch import utils as graph_utils
128from graph_net .torch .rp_expr .rp_expr_parser import RpExprParser
9+ from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
10+ from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module_without_varify
1311
1412
1513class TypicalSequenceExtractor :
@@ -28,9 +26,12 @@ def _extract_operators_from_graph(
2826
2927 if node .op == "call_module" :
3028 target_name = type (named_modules [node .target ]).__name__
31- else :
29+ elif node .op == "call_method" :
30+ target_name = f"Tensor.{ node .target } "
31+ elif node .op == "call_function" :
3232 target_name = getattr (node .target , "__name__" , str (node .target ))
33-
33+ else :
34+ raise NotImplementedError ()
3435 operator_list .append (
3536 {
3637 "op_type" : node .op ,
@@ -48,39 +49,6 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
4849 return gm .forward
4950
5051
51- class TypicalSequenceModelLoader :
52- def load_class_from_file (self , model_path : str , device : str ) -> Any :
53- file_path = os .path .join (model_path , "model.py" )
54-
55- if not os .path .exists (file_path ):
56- raise FileNotFoundError (f"Model file not found: { file_path } " )
57-
58- with open (file_path , "r" , encoding = "utf-8" ) as f :
59- model_code = f .read ()
60- model_code = graph_utils .modify_code_by_device (model_code , device )
61-
62- with tempfile .NamedTemporaryFile (
63- mode = "w" , suffix = ".py" , encoding = "utf-8"
64- ) as temp_file :
65- temp_file .write (model_code )
66- module = graph_net .imp_util .load_module (temp_file .name )
67- model_class = getattr (module , "GraphModule" , None )
68-
69- return model_class
70-
71- def get_input_dict (self , model_path : str , device : str ) -> Dict [str , torch .Tensor ]:
72- inputs_params = graph_utils .load_converted_from_text (f"{ model_path } " )
73- params = inputs_params ["weight_info" ]
74- for tensor_meta in params .values ():
75- if hasattr (tensor_meta , "device" ):
76- tensor_meta .device = device
77- input_dict = {
78- k : graph_utils .replay_tensor (v ).to (torch .device (device ))
79- for k , v in params .items ()
80- }
81- return input_dict
82-
83-
8452class SplitAnalyzer :
8553 def __init__ (
8654 self , window_size : int = 10 , fold_policy : str = "default" , fold_times : int = 0
@@ -109,20 +77,11 @@ def _resolve_token_to_ops(
10977 def _extract_ops_via_compile (
11078 self , model_path : str , device : str = "cpu"
11179 ) -> List [str ]:
112- loader = TypicalSequenceModelLoader ()
113- print (f"Loading model from { model_path } on { device } ..." )
114- try :
115- model_class = loader .load_class_from_file (model_path , device )
116- model = model_class ().to (torch .device (device ))
117- model .eval ()
118- input_dict = loader .get_input_dict (model_path , device )
119- except Exception as e :
120- print (f"Error loading/preparing model { model_path } : { e } " )
121- return []
122-
80+ print (f"extracting ops from { model_path } " )
12381 extractor = TypicalSequenceExtractor ()
124- compiled_model = torch .compile (model , backend = extractor .extract_compiler )
125- compiled_model (** input_dict )
82+ model , inputs = get_torch_module_and_inputs (model_path )
83+ compiled_model , _ = parse_sole_graph_module_without_varify (model , inputs )
84+ extractor .extract_compiler (compiled_model , inputs )
12685 ops_info = extractor .extract_node
12786
12887 return [op ["target_name" ] for op in ops_info ]
@@ -150,11 +109,13 @@ def get_len(tid):
150109 get_len (sym_id )
151110 return token2len
152111
153- def analyze (self , model_paths_file : str , device : str ) -> Dict [str , Dict ]:
112+ def analyze (
113+ self , model_path_prefix : str , model_paths_file : str , device : str
114+ ) -> Dict [str , Dict ]:
154115 input_file = Path (model_paths_file )
155116
156117 with open (input_file , "r" ) as f :
157- model_paths = [
118+ rel_model_paths = [
158119 Path (line .strip ())
159120 for line in f
160121 if line .strip () and not line .startswith ("#" )
@@ -163,15 +124,15 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
163124 inputs_seqs = []
164125 valid_models = []
165126
166- for p in model_paths :
167- seq = self ._extract_ops_via_compile (str (p ), device )
127+ for p in rel_model_paths :
128+ model_full_path = os .path .join (model_path_prefix , p )
129+ seq = self ._extract_ops_via_compile (model_full_path , device )
168130 if seq :
169131 inputs_seqs .append (seq )
170132 valid_models .append ((p .name , p ))
171133
172134 if not inputs_seqs :
173135 return {}
174-
175136 rp_parser = RpExprParser (
176137 window_size = self .window_size ,
177138 fold_policy = self .fold_policy ,
@@ -264,7 +225,7 @@ def main(args):
264225 fold_policy = args .fold_policy ,
265226 fold_times = args .fold_times ,
266227 )
267- results = analyzer .analyze (args .model_list , args .device )
228+ results = analyzer .analyze (args .model_path_prefix , args . model_list , args .device )
268229 if args .output_json :
269230 with open (args .output_json , "w" ) as f :
270231 json .dump (results , f , indent = 4 )
@@ -280,6 +241,12 @@ def main(args):
280241 required = True ,
281242 help = "Path to a text file containing paths to models (one per line)." ,
282243 )
244+ parser .add_argument (
245+ "--model-path-prefix" ,
246+ type = str ,
247+ default = "./" ,
248+ help = "Prefix to add to each model path in the list." ,
249+ )
283250 parser .add_argument (
284251 "--device" ,
285252 type = str ,
0 commit comments