11import torch
2+ import inspect
23import torch .nn as nn
34import os
45import importlib .util
56from typing import List
67
78
89class ComposedModel (nn .Module ):
9- def __init__ (self , subgraph : List [nn .Module ]):
10+ def __init__ (self , subgraphs : List [nn .Module ]):
1011 super ().__init__ ()
11- self .subgraphs = nn .ModuleList (subgraph )
12+ self .subgraphs = nn .ModuleList (subgraphs )
1213
1314 def forward (self , ** kwargs ):
1415 output = None
1516 for i , subgraph in enumerate (self .subgraphs ):
16- print (f"{ i = } subgraph begin" )
1717 if output is None :
18- output = subgraph (** kwargs )
18+ output = subgraph (** self . _convert_inputs ( subgraph , kwargs ) )
1919 else :
2020 output = subgraph (* output )
21- print (f"{ i = } subgraph end" )
2221
2322 return output
2423
24+ def _convert_inputs (self , subgraph , input_kwargs ):
25+ input_keywords = set (name for name , _ in input_kwargs .items ())
26+ sub_graph_arg_names = set (inspect .signature (subgraph .forward ).parameters )
27+ assert (
28+ len (sub_graph_arg_names - input_keywords ) == 0
29+ ), f"{ (sub_graph_arg_names - input_keywords )= } "
30+ for remainder in input_keywords - sub_graph_arg_names :
31+ assert remainder .startswith ("s" )
32+ assert remainder [1 :].isdigit ()
33+ return {
34+ name : value
35+ for name , value in input_kwargs .items ()
36+ if name in sub_graph_arg_names
37+ }
38+
2539
2640class RangeDecomposerValidatorBackend :
2741 def _load_model_instance (self , path : str , device : str ) -> torch .nn .Module :
@@ -36,40 +50,56 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
3650 instance = ModelClass ().to (device )
3751 return instance
3852
39- def _make_config (self , decomposed_root , decomposed_model_name_suffix = "_decomposed" ):
53+ def _make_config (
54+ self ,
55+ model_path_prefix : str ,
56+ decomposed_root : str ,
57+ decomposed_dentry : str = "_decomposed" ,
58+ ):
4059 return {
60+ "model_path_prefix" : model_path_prefix ,
4161 "decomposed_root" : decomposed_root ,
42- "decomposed_model_name_suffix " : decomposed_model_name_suffix ,
62+ "decomposed_dentry " : decomposed_dentry ,
4363 }
4464
65+ def _get_rel_model_path (self , model_path ) -> str :
66+ model_path = os .path .realpath (model_path )
67+ model_path_prefix = os .path .realpath (self .config ["model_path_prefix" ])
68+ assert model_path .startswith (model_path_prefix )
69+ rel_model_path = model_path [len (model_path_prefix ) :]
70+ if rel_model_path .startswith ("/" ):
71+ rel_model_path = rel_model_path [1 :]
72+ assert not rel_model_path .startswith ("/" )
73+ return rel_model_path
74+
75+ def _get_model_name_order (self , name ):
76+ lst = name .split ("_" )
77+ if not (len (lst ) > 0 ):
78+ return - 1
79+ if not (lst [- 1 ].isdigit ()):
80+ return - 1
81+ return int (lst [- 1 ])
82+
4583 def __call__ (self , model : torch .nn .Module ) -> torch .nn .Module :
4684 config = self ._make_config (** self .config )
47- model_file_path = model .__class__ .__graph_net_file_path__
48- model_dir = os .path .dirname (model_file_path )
49- model_name = os .path .basename (model_dir )
85+ model_path = os .path .dirname (model .__class__ .__graph_net_file_path__ )
86+ rel_model_path = self ._get_rel_model_path (model_path )
5087 decomposed_parent_dir = os .path .join (
51- config ["decomposed_root" ], f" { model_name } _decomposed"
88+ config ["decomposed_root" ], rel_model_path , config [ "decomposed_dentry" ]
5289 )
5390 subgraph_paths = []
54- for name in sorted (os .listdir (decomposed_parent_dir )):
91+ dentries = os .listdir (decomposed_parent_dir )
92+ for name in sorted (dentries , key = self ._get_model_name_order ):
5593 full_path = os .path .join (decomposed_parent_dir , name )
56- if os .path .isdir (full_path ) and name [ - 1 ]. isdigit () :
94+ if os .path .isdir (full_path ) and self . _get_model_name_order ( name ) >= 0 :
5795 subgraph_paths .append (full_path )
5896
59- print (
60- f"[RangeDecomposerValidatorBackend] Found subgraphs: { [os .path .basename (p ) for p in subgraph_paths ]} "
61- )
62-
6397 device = model .__class__ .__graph_net_device__
6498 subgraph_instances = []
6599
66100 for path in subgraph_paths :
67101 instance = self ._load_model_instance (path , device )
68102 subgraph_instances .append (instance )
69- dir_name = os .path .basename (path )
70- print (
71- f"[RangeDecomposerValidatorBackend] Loaded and instantiated '{ dir_name } '"
72- )
73103
74104 composed_model = ComposedModel (subgraph_instances )
75105 return composed_model .eval ()
0 commit comments