33import os
44from pathlib import Path
55from typing import Any , Dict , List
6-
76import torch
87import torch .nn as nn
9- import tempfile
10- import graph_net .imp_util
11- from graph_net .torch import utils as graph_utils
128from graph_net .torch .rp_expr .rp_expr_parser import RpExprParser
9+ from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
10+ from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module_without_varify
1311
1412
1513class TypicalSequenceExtractor :
@@ -28,9 +26,12 @@ def _extract_operators_from_graph(
2826
2927 if node .op == "call_module" :
3028 target_name = type (named_modules [node .target ]).__name__
31- else :
29+ elif node .op == "call_method" :
30+ target_name = f"Tensor.{ node .target } "
31+ elif node .op == "call_function" :
3232 target_name = getattr (node .target , "__name__" , str (node .target ))
33-
33+ else :
34+ raise NotImplementedError ()
3435 operator_list .append (
3536 {
3637 "op_type" : node .op ,
@@ -48,39 +49,6 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
4849 return gm .forward
4950
5051
class TypicalSequenceModelLoader:
    """Load a serialized ``GraphModule`` class and its replayable inputs
    from an on-disk model directory."""

    def load_class_from_file(self, model_path: str, device: str) -> Any:
        """Read ``<model_path>/model.py``, rewrite its device strings for
        *device*, import the rewritten source, and return its ``GraphModule``
        class (or ``None`` if the generated module does not define one).

        Raises:
            FileNotFoundError: if ``model.py`` is missing under *model_path*.
        """
        file_path = os.path.join(model_path, "model.py")

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Model file not found: {file_path}")

        with open(file_path, "r", encoding="utf-8") as f:
            model_code = f.read()
        model_code = graph_utils.modify_code_by_device(model_code, device)

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", encoding="utf-8"
        ) as temp_file:
            temp_file.write(model_code)
            # BUG FIX: flush buffered writes before the importer reads the
            # file by name — without this, load_module can observe a
            # truncated (or empty) source file. The load must happen inside
            # the `with` block because the temp file is deleted on close.
            temp_file.flush()
            module = graph_net.imp_util.load_module(temp_file.name)
            model_class = getattr(module, "GraphModule", None)

        return model_class

    def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]:
        """Rebuild the model's input tensors on *device* from saved metadata.

        Returns a mapping from input name to a replayed tensor moved to
        ``torch.device(device)``.
        """
        inputs_params = graph_utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        # Retarget each recorded tensor's device before replaying it, so the
        # replay itself allocates on the requested device when supported.
        for tensor_meta in params.values():
            if hasattr(tensor_meta, "device"):
                tensor_meta.device = device
        input_dict = {
            k: graph_utils.replay_tensor(v).to(torch.device(device))
            for k, v in params.items()
        }
        return input_dict
83-
8452class SplitAnalyzer :
8553 def __init__ (
8654 self , window_size : int = 10 , fold_policy : str = "default" , fold_times : int = 0
@@ -109,20 +77,11 @@ def _resolve_token_to_ops(
10977 def _extract_ops_via_compile (
11078 self , model_path : str , device : str = "cpu"
11179 ) -> List [str ]:
112- loader = TypicalSequenceModelLoader ()
113- print (f"Loading model from { model_path } on { device } ..." )
114- try :
115- model_class = loader .load_class_from_file (model_path , device )
116- model = model_class ().to (torch .device (device ))
117- model .eval ()
118- input_dict = loader .get_input_dict (model_path , device )
119- except Exception as e :
120- print (f"Error loading/preparing model { model_path } : { e } " )
121- return []
122-
80+ print (f"extracting ops from { model_path } " )
12381 extractor = TypicalSequenceExtractor ()
124- compiled_model = torch .compile (model , backend = extractor .extract_compiler )
125- compiled_model (** input_dict )
82+ model , inputs = get_torch_module_and_inputs (model_path )
83+ compiled_model , _ = parse_sole_graph_module_without_varify (model , inputs )
84+ extractor .extract_compiler (compiled_model , inputs )
12685 ops_info = extractor .extract_node
12786
12887 return [op ["target_name" ] for op in ops_info ]
@@ -174,7 +133,6 @@ def analyze(
174133
175134 if not inputs_seqs :
176135 return {}
177-
178136 rp_parser = RpExprParser (
179137 window_size = self .window_size ,
180138 fold_policy = self .fold_policy ,
0 commit comments