11import os
22import torch
3- import sys
43import graph_net
4+ import tempfile
5+ from graph_net .torch import constraint_util
56
67
78class GraphExtractor :
@@ -25,11 +26,6 @@ def make_config(
2526 split_positions = (),
2627 group_head_and_tail = False ,
2728 chain_style = False ,
28- output_dir = "./tmp/naive_decomposer_dir" ,
29- filter_path = None ,
30- filter_config = None ,
31- post_extract_process_path = None ,
32- post_extract_process_class_name = None ,
3329 max_step = 8 ,
3430 min_step = 2 ,
3531 max_nodes = 32 ,
@@ -42,40 +38,38 @@ def make_config(
4238 "split_positions" : split_positions ,
4339 "group_head_and_tail" : group_head_and_tail ,
4440 "chain_style" : chain_style ,
45- "output_dir" : output_dir ,
46- "filter_path" : filter_path ,
47- "filter_config" : filter_config if filter_config is not None else {},
48- "post_extract_process_path" : post_extract_process_path ,
49- "post_extract_process_class_name" : post_extract_process_class_name ,
5041 "max_step" : max_step ,
5142 "min_step" : min_step ,
5243 "max_nodes" : max_nodes ,
5344 }
5445
5546 def _get_sub_ranges (self ):
56- kMinLenOps = self . config [ "min_step" ]
57- num_ops = self .config ["max_nodes" ]
58- for length in reversed ( range ( kMinLenOps , num_ops ) ):
59- for start_pos in range (num_ops - length ):
60- end_pos = start_pos + length
47+ for step in reversed (
48+ range ( self .config ["min_step" ], self . config [ "max_step" ] + 1 )
49+ ):
50+ for start_pos in range (self . config [ "max_nodes" ] - step ):
51+ end_pos = start_pos + step
6152 yield start_pos , end_pos
6253
6354 def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
64- import json
65- import base64
66-
55+ temp_dir_obj = tempfile .TemporaryDirectory (prefix = "_check_fusable_" )
56+ temp_output_dir = temp_dir_obj .name
57+ found_fusable_subgraph = False
58+ print (f"Using temp output dir: { temp_output_dir } " )
6759 for start_pos , end_pos in self ._get_sub_ranges ():
6860 self .config ["split_positions" ] = [start_pos , end_pos ]
6961 print ("current split_positions:" , self .config ["split_positions" ])
7062 graph_net_root = os .path .dirname (graph_net .__file__ )
71- model_path = f"{ graph_net_root } /../samples//timm/{ self .name } "
63+ model_path = os .path .join (
64+ graph_net_root , ".." , "samples" , "timm" , self .name
65+ )
7266 check_fusable_config = {
7367 "decorator_path" : f"{ graph_net_root } /torch/extractor.py" ,
7468 "decorator_config" : {
7569 "name" : f"{ self .name } " ,
7670 "custom_extractor_path" : f"{ graph_net_root } /torch/naive_graph_decomposer.py" ,
7771 "custom_extractor_config" : {
78- "output_dir" : "/tmp/naive_decompose_workspace" ,
72+ "output_dir" : temp_output_dir ,
7973 "split_positions" : self .config ["split_positions" ],
8074 "group_head_and_tail" : False ,
8175 "filter_path" : f"{ graph_net_root } /torch/naive_subgraph_filter.py" ,
@@ -85,15 +79,20 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
8579 },
8680 },
8781 }
88- json_string = json .dumps (check_fusable_config )
89- json_bytes = json_string .encode ("utf-8" )
90- b64_encoded_bytes = base64 .b64encode (json_bytes )
91- checker_config = b64_encoded_bytes .decode ("utf-8" )
92- cmd = f"{ sys .executable } -m graph_net.torch.run_model --model-path { model_path } --decorator-config '{ checker_config } '"
93- res_code = os .system (cmd )
94- if res_code == 0 :
95- print ("find the biggest fully fusable subgraph" )
82+ success = constraint_util .RunModelPredicator (check_fusable_config )(
83+ model_path
84+ )
85+ if success :
86+ found_fusable_subgraph = True
87+ temp_dir_obj .cleanup = lambda : None
88+ print (
89+ f"SUCCESS in finding the biggest fully fusable subgraph saved in: { temp_output_dir } ."
90+ )
9691 break
9792 else :
93+ print ("Failed attempt. clean up the workspace and continue the search." )
94+ temp_dir_obj .cleanup ()
9895 continue
96+ if not found_fusable_subgraph :
97+ print ("No fusable subgraph found" )
9998 return gm
0 commit comments