11import os
22import torch
3- import json
4- import base64
53import shutil
64from typing import Union , Callable
75from graph_net .torch import utils
86from graph_net .torch .decompose_util import convert_to_submodules_graph
97from graph_net .torch .extractor import GraphExtractor as BuiltinGraphExtractor
8+ import graph_net .imp_util as imp_util
109
1110
1211class GraphExtractor :
1312 def __init__ (
1413 self ,
15- config_str : str ,
14+ config : dict ,
1615 name ,
1716 dynamic ,
1817 mut_graph_codes = None ,
@@ -23,14 +22,16 @@ def __init__(
2322 self .dynamic = dynamic
2423 self .mut_graph_codes = mut_graph_codes
2524 self .placeholder_auto_rename = placeholder_auto_rename
26- self .config = self .make_config (** self . convert_to_dict ( config_str ) )
25+ self .config = self .make_config (** config )
2726
2827 def make_config (
2928 self ,
3029 split_positions = (),
3130 group_head_and_tail = False ,
3231 chain_style = False ,
3332 output_dir = "./tmp/naive_decomposer_dir" ,
33+ filter_path = None ,
34+ filter_config = None ,
3435 ):
3536 for pos in split_positions :
3637 assert isinstance (
@@ -41,6 +42,8 @@ def make_config(
4142 "group_head_and_tail" : group_head_and_tail ,
4243 "chain_style" : chain_style ,
4344 "output_dir" : output_dir ,
45+ "filter_path" : filter_path ,
46+ "filter_config" : filter_config if filter_config is not None else {},
4447 }
4548
4649 def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
@@ -59,14 +62,6 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
5962 def get_naive_decomposer_extractor (self , submodule , seq_no ):
6063 return NaiveDecomposerExtractor (self , submodule , seq_no )
6164
62- def convert_to_dict (self , config_str ):
63- if config_str is None :
64- return {}
65- config_str = base64 .b64decode (config_str ).decode ("utf-8" )
66- config = json .loads (config_str )
67- assert isinstance (config , dict ), f"config should be a dict. { config_str = } "
68- return config
69-
7065
7166class NaiveDecomposerExtractor (torch .nn .Module ):
7267 def __init__ (self , parent_graph_extractor , submodule , seq_no ):
@@ -83,9 +78,22 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
8378 placeholder_auto_rename = parent_graph_extractor .placeholder_auto_rename ,
8479 workspace_path = self .parent_graph_extractor .config ["output_dir" ],
8580 )
81+ self .filter = self .make_filter (self .parent_graph_extractor .config )
8682
8783 def forward (self , * args ):
8884 if not self .extracted :
89- self .builtin_extractor (self .submodule , args )
85+ if self .need_extract (self .submodule , args ):
86+ self .builtin_extractor (self .submodule , args )
9087 self .extracted = True
9188 return self .submodule (* args )
89+
90+ def need_extract (self , gm , sample_inputs ):
91+ if self .filter is None :
92+ return True
93+ return self .filter (gm , sample_inputs )
94+
95+ def make_filter (self , config ):
96+ if config ["filter_path" ] is None :
97+ return None
98+ module = imp_util .load_module (config ["filter_path" ])
99+ return module .GraphFilter (config ["filter_config" ])
0 commit comments