11import os
2+ import torch
23from pathlib import Path
3- import graph_net
44import tempfile
55import shutil
6- from graph_net .torch import fully_fusible_graph_predicator
7- from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
8- from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module
6+ from graph_net .torch .graph_decomposer import NaiveDecomposerExtractor
7+ from graph_net .torch .fully_fusible_graph_predicator import (
8+ FullyFusibleSubGraphPredicator ,
9+ )
910import logging
1011
1112logger = logging .getLogger (__name__ )
@@ -19,24 +20,22 @@ def __init__(self, config: dict = None):
1920
2021 def _make_config (
2122 self ,
23+ nn_module_fully_fusible_decorator_path ,
24+ nn_module_fully_fusible_decorator_class_name ,
25+ nn_module_fully_fusible_decorator_config = None ,
2226 output_dir = None ,
23- split_positions = (),
24- group_head_and_tail = False ,
25- chain_style = False ,
27+ resume : bool = True ,
2628 max_step = 8 ,
2729 min_step = 2 ,
2830 max_nodes = 32 ,
2931 model_path_prefix = "" ,
3032 ):
31- for pos in split_positions :
32- assert isinstance (
33- pos , int
34- ), f"split_positions should be list of int, { split_positions = } "
3533 return {
3634 "output_dir" : output_dir ,
37- "split_positions" : split_positions ,
38- "group_head_and_tail" : group_head_and_tail ,
39- "chain_style" : chain_style ,
35+ "resume" : resume ,
36+ "nn_module_fully_fusible_decorator_path" : nn_module_fully_fusible_decorator_path ,
37+ "nn_module_fully_fusible_decorator_class_name" : nn_module_fully_fusible_decorator_class_name ,
38+ "nn_module_fully_fusible_decorator_config" : nn_module_fully_fusible_decorator_config ,
4039 "max_step" : max_step ,
4140 "min_step" : min_step ,
4241 "max_nodes" : max_nodes ,
@@ -61,7 +60,9 @@ def _get_sub_ranges(self):
6160 ), f"Invalid range generated: start={ start_pos } , end={ end_pos } , max={ self .config ['max_nodes' ]} "
6261 yield start_pos , end_pos
6362
64- def _handle_success (self , temp_dir : str , rel_model_path : str ) -> str :
63+ def _copy_from_tmp_dir_to_output_dir (
64+ self , temp_dir : str , rel_model_path : str
65+ ) -> str :
6566 subdirs = list (Path (temp_dir ).iterdir ())
6667 assert len (subdirs ) == 1
6768 temp_dir = str (subdirs [0 ])
@@ -74,57 +75,62 @@ def _handle_success(self, temp_dir: str, rel_model_path: str) -> str:
7475 return target_path
7576
7677 def _build_decompose_config (
77- self , temp_dir : str , start_pos : int , end_pos : int , model_path_prefix
78+ self , temp_dir : str , start_pos : int , end_pos : int
7879 ) -> dict :
79- graph_net_root = os .path .dirname (graph_net .__file__ )
80+ model_path_prefix = self .config ["model_path_prefix" ]
81+ decomposer_config = {
82+ "model_path_prefix" : model_path_prefix ,
83+ "output_dir" : temp_dir ,
84+ "split_positions" : [start_pos , end_pos ],
85+ "group_head_and_tail" : False ,
86+ }
87+ return decomposer_config
8088
81- check_fusible_config = {
82- "handler_path" : f"{ graph_net_root } /torch/graph_decomposer.py" ,
83- "handler_class_name" : "NaiveDecomposerExtractor" ,
84- "handler_config" : {
85- "model_path_prefix" : model_path_prefix ,
86- "output_dir" : temp_dir ,
87- "split_positions" : [start_pos , end_pos ],
88- "group_head_and_tail" : False ,
89- "post_extract_process_path" : f"{ graph_net_root } /torch/post_extract_process_count_kernels.py" ,
90- "post_extract_process_class_name" : "ThrowExitStatusIfGraphFullyFusible" ,
91- },
89+ def _get_fully_fusible_subgraph_predicator (self , model_path ):
90+ config = {
91+ "model_path" : model_path ,
92+ "nn_module_fully_fusible_decorator_path" : self .config [
93+ "nn_module_fully_fusible_decorator_path"
94+ ],
95+ "nn_module_fully_fusible_decorator_class_name" : self .config [
96+ "nn_module_fully_fusible_decorator_class_name"
97+ ],
98+ "nn_module_fully_fusible_decorator_config" : self .config [
99+ "nn_module_fully_fusible_decorator_config"
100+ ],
92101 }
93- return check_fusible_config
102+ return FullyFusibleSubGraphPredicator (config )
103+
104+ def _is_model_path_handled (self , rel_model_path ):
105+ model_path = Path (self .config ["output_dir" ]) / rel_model_path
106+ return model_path .exists () and len (list (model_path .iterdir ())) > 0
94107
95108 def __call__ (self , rel_model_path ):
109+ if self .config ["resume" ] and self ._is_model_path_handled (rel_model_path ):
110+ return
111+ torch .cuda .empty_cache ()
96112 model_path = os .path .join (self .config ["model_path_prefix" ], rel_model_path )
97- module , inputs = get_torch_module_and_inputs (model_path )
98- gm = parse_sole_graph_module (module , inputs )
113+ fully_fusible_subgraph_predicator = self ._get_fully_fusible_subgraph_predicator (
114+ model_path
115+ )
99116 for start_pos , end_pos in self ._get_sub_ranges ():
117+ logger .warning ("fully_fusible_subgraph_predicator-begin" )
118+ success = fully_fusible_subgraph_predicator (start_pos , end_pos )
119+ logger .warning ("fully_fusible_subgraph_predicator-end" )
120+ if not success :
121+ continue
100122 with tempfile .TemporaryDirectory (
101123 prefix = "_find_fusible_subgraph_"
102124 ) as temp_dir :
103- check_fusible_config = self ._build_decompose_config (
104- temp_dir , start_pos , end_pos , self .config ["model_path_prefix" ]
105- )
106- predicator_cls = (
107- fully_fusible_graph_predicator .FullyFusibleGraphPredicator
108- )
109- predicator = predicator_cls (check_fusible_config )
110- logger .warning ("fully_fusible_graph_predicator-begin" )
111- success = predicator (model_path )
112- logger .warning ("fully_fusible_graph_predicator-end" )
113- if not success :
114- continue
115125 decomposer_config = self ._build_decompose_config (
116- temp_dir , start_pos , end_pos , self .config ["model_path_prefix" ]
117- )
118- predicator_cls = (
119- fully_fusible_graph_predicator .FullyFusibleGraphPredicator
126+ temp_dir , start_pos , end_pos
120127 )
121- predicator = predicator_cls (decomposer_config )
122- predicator (model_path )
123- target_path = self ._handle_success (temp_dir , rel_model_path )
124- print (
125- f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: { target_path } "
128+ naive_graph_decomposer = NaiveDecomposerExtractor (decomposer_config )
129+ logger .warning ("naive_graph_decomposer-begin" )
130+ naive_graph_decomposer (rel_model_path )
131+ logger .warning ("naive_graph_decomposer-end" )
132+ fully_fusible_destination_path = self ._copy_from_tmp_dir_to_output_dir (
133+ temp_dir , rel_model_path
126134 )
127- break
128- else :
129- logger .warning ("fail to find fully fusible subgraph" )
130- return gm .forward
135+ print (f"{ fully_fusible_destination_path = } " )
136+ return
0 commit comments