44import sys
55import inspect
66import importlib .util
7- from typing import List , Dict
7+ import itertools
8+ from typing import List , Tuple , Dict , Any , Callable
89
910
1011class ComposedModel (nn .Module ):
11- def __init__ (self , submodules : List [nn .Module ]):
12+ def __init__ (self , graph : nn . Module , subgraph : List [nn .Module ]):
1213 super ().__init__ ()
13- self .submodules = nn .ModuleList (submodules )
14- self .submodule_param_names = [
14+ self .graph = graph
15+ self .subgraph = nn .ModuleList (subgraph )
16+ self .subgraph_param_names = [
1517 list (inspect .signature (sm .forward ).parameters .keys ())
16- for sm in self .submodules
18+ for sm in self .subgraph
1719 ]
20+ self .extract_node = []
21+
22+ def _serialize_arg (self , arg : Any ) -> Any :
23+ if isinstance (arg , torch .fx .Node ):
24+ return arg .name
25+ if isinstance (arg , (list , tuple )):
26+ return type (arg )(self ._serialize_arg (elem ) for elem in arg )
27+ if isinstance (arg , dict ):
28+ return {
29+ self ._serialize_arg (k ): self ._serialize_arg (v ) for k , v in arg .items ()
30+ }
31+ return arg
32+
33+ def _extract_operators_from_graph (
34+ self , gm : nn .Module , example_inputs : List [torch .Tensor ] = None
35+ ) -> List [Dict [str , Any ]]:
36+ operator_list = []
37+ for node in gm .graph .nodes :
38+ if node .op in ("call_method" , "call_function" , "call_module" ):
39+ operator_info = {
40+ "op_type" : node .op ,
41+ "target" : node .target ,
42+ "name" : node .name ,
43+ "kwargs" : self ._serialize_arg (node .kwargs ),
44+ }
45+
46+ if isinstance (node .target , Callable ):
47+ try :
48+ operator_info ["target_name" ] = node .target .__name__
49+ except AttributeError :
50+ operator_info ["target_name" ] = str (node .target )
51+ else :
52+ operator_info ["target_name" ] = str (node .target )
53+
54+ operator_list .append (operator_info )
55+
56+ return operator_list
57+
58+ def extract_compiler (self , gm : torch .fx .GraphModule , inputs : List [torch .Tensor ]):
59+ operator = self ._extract_operators_from_graph (gm , inputs )
60+ self .extract_node .append (operator )
61+ return gm .forward
1862
1963 def forward (self , ** kwargs ):
2064 current_args = kwargs
65+ compiled_model = torch .compile (self .graph , backend = self .extract_compiler )
66+ compiled_model (** current_args )
67+ graph_node_list = list (itertools .chain .from_iterable (self .extract_node ))
68+ self .extract_node = []
69+
2170 for i , (sm , param_names ) in enumerate (
22- zip (self .submodules , self .submodule_param_names )
71+ zip (self .subgraph , self .subgraph_param_names )
2372 ):
24- # 准备当前子图的输入字典
2573 call_kwargs = {}
2674 if i > 0 :
27- # 对于后续子图,第一个参数是上一个子图的输出
2875 first_param_name = param_names [0 ]
29- call_kwargs [first_param_name ] = current_args # current_args 此时是上一个子图的输出
76+ call_kwargs [first_param_name ] = current_args
77+ remaining_params = param_names [1 :]
78+ else :
79+ remaining_params = param_names
3080
31- # 从主输入字典中筛选出当前子图需要的权重参数
32- for name in param_names :
33- if name in current_args :
34- call_kwargs [name ] = current_args [name ]
81+ for name in remaining_params :
82+ if name in kwargs :
83+ call_kwargs [name ] = kwargs [name ]
3584
36- outputs = sm ( ** call_kwargs )
37- # 假设每个子图只有一个输出,并且返回的是一个元组
85+ compiled_model = torch . compile ( sm , backend = self . extract_compiler )
86+ outputs = compiled_model ( ** call_kwargs )
3887 current_args = outputs [0 ]
3988
89+ subgraph_node_list = list (itertools .chain .from_iterable (self .extract_node ))
90+ self .extract_node = []
91+
92+ if graph_node_list != subgraph_node_list :
93+ diff_in_graph = [
94+ item for item in graph_node_list if item not in subgraph_node_list
95+ ]
96+ diff_in_subgraph = [
97+ item for item in subgraph_node_list if item not in graph_node_list
98+ ]
99+
100+ error_msg = f"Subgraph segmentation verification failed\n "
101+ error_msg += f"Nodes in graph but not in subgraph: { diff_in_graph } \n "
102+ error_msg += f"Nodes in subgraph but not in graph: { diff_in_subgraph } "
103+ raise ValueError (error_msg )
104+ else :
105+ print ("" )
106+
40107 return (current_args ,)
41108
42109
@@ -54,36 +121,32 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
54121 return instance
55122
56123 def __call__ (self , model : torch .nn .Module ) -> torch .nn .Module :
57- model_file_path = inspect .getfile (
58- model .__class__
59- ) # e.g., /test/simple_CNN/model.py
60- model_dir = os .path .dirname (model_file_path ) # e.g., /test/simple_CNN
61-
62- decomposed_parent_dir = (
63- model_dir + "_decomposed"
64- ) # e.g., /test/simple_CNN_decomposed
124+ model_file_path = model .__class__ .__file_path__
125+ model_dir = os .path .dirname (model_file_path )
126+ decomposed_parent_dir = model_dir + "_decomposed"
65127 subgraph_paths = []
66128 for name in sorted (os .listdir (decomposed_parent_dir )):
67129 full_path = os .path .join (decomposed_parent_dir , name )
68- if os .path .isdir (full_path ) and name . startswith ( "subgraph_" ):
130+ if os .path .isdir (full_path ) and name [ - 1 ]. isdigit ( ):
69131 subgraph_paths .append (full_path )
70132
71133 print (
72134 f"[RangeDecomposerValidatorBackend] Found subgraphs: { [os .path .basename (p ) for p in subgraph_paths ]} "
73135 )
74136
75- submodule_instances = []
76- device = next (model .parameters ()).device # 从传入的model获取device信息
137+ device = model .__class__ .__device__
138+ graph_instances = self ._load_model_instance (model_dir , device )
139+ subgraph_instances = []
77140
78141 for path in subgraph_paths :
79142 instance = self ._load_model_instance (path , device )
80- submodule_instances .append (instance )
143+ subgraph_instances .append (instance )
81144 dir_name = os .path .basename (path )
82145 print (
83146 f"[RangeDecomposerValidatorBackend] Loaded and instantiated '{ dir_name } '"
84147 )
85148
86- composed_model = ComposedModel (submodule_instances )
149+ composed_model = ComposedModel (graph_instances , subgraph_instances )
87150 return composed_model .eval ()
88151
89152 def synchronize (self ):
0 commit comments