7
7
8
8
torch ._dynamo .config .capture_scalar_outputs = True
9
9
torch ._dynamo .config .capture_dynamic_output_shape_ops = True
10
+ torch ._dynamo .config .capture_sparse_compute = True
11
+ torch ._dynamo .config .raise_on_ctx_manager_usage = False
12
+ torch ._dynamo .config .allow_rnn = True
10
13
11
14
12
15
def extract (name , dynamic = True , mut_graph_codes = None , placeholder_auto_rename = False ):
@@ -84,6 +87,11 @@ def wrapper(model: torch.nn.Module):
84
87
class GraphExtractor :
85
88
def __init__ (self ):
86
89
self .subgraph_counter = 0
90
+ self .workspace_path = os .environ .get ("GRAPH_NET_EXTRACT_WORKSPACE" )
91
+ if not self .workspace_path :
92
+ raise EnvironmentError (
93
+ "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
94
+ )
87
95
88
96
def move_files (self , source_dir , target_dir ):
89
97
os .makedirs (target_dir , exist_ok = True )
@@ -94,13 +102,8 @@ def move_files(self, source_dir, target_dir):
94
102
shutil .move (source_path , target_path )
95
103
96
104
def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
97
- # 1. Get workspace path
98
- workspace_path = os .environ .get ("GRAPH_NET_EXTRACT_WORKSPACE" )
99
- if not workspace_path :
100
- raise EnvironmentError (
101
- "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
102
- )
103
- model_path = os .path .join (workspace_path , name )
105
+ # 1. Get model path
106
+ model_path = os .path .join (self .workspace_path , name )
104
107
os .makedirs (model_path , exist_ok = True )
105
108
106
109
if self .subgraph_counter == 0 :
@@ -140,6 +143,15 @@ def try_rename_placeholder(node):
140
143
input = torch .tensor (4 )
141
144
params [node .target ] = input
142
145
input_idx += 1
146
+
147
+ if node .op == "call_function" and hasattr (node .target , "__name__" ):
148
+ if node .target .__name__ in [
149
+ "_enter_autocast" ,
150
+ "_exit_autocast" ,
151
+ ]:
152
+ node .replace_all_uses_with (node .args [0 ])
153
+ gm .graph .erase_node (node )
154
+
143
155
assert input_idx == len (sample_inputs )
144
156
if mut_graph_codes is not None :
145
157
assert isinstance (mut_graph_codes , list )
0 commit comments