11from graph_net .sample_pass .sample_pass import SamplePass
22from graph_net .sample_pass .resumable_sample_pass_mixin import ResumableSamplePassMixin
3+ from graph_net .optional import Optional
34from graph_net .torch .fx_graph_cache_util import (
45 parse_immutable_model_path_into_sole_graph_module ,
56)
67from graph_net .torch .decompose_util import convert_to_submodules_graph
7- from graph_net .torch .count_kernels_util import count_kernels
8+ from graph_net .torch .count_kernels_util import CountNumKernelsNNModule
89from graph_net .torch .fx_graph_module_util import (
910 get_fx_graph_num_ops ,
1011 get_torch_module_and_inputs ,
@@ -41,6 +42,7 @@ def resume(self, rel_model_path: str):
4142 cumsum_num_kernels = analyzer .analyze ()
4243 cumsum_num_kernels_json = json .dumps (cumsum_num_kernels , indent = 4 )
4344 output_dir_path = Path (self .config ["output_dir" ]) / rel_model_path
45+ output_dir_path .mkdir (parents = True , exist_ok = True )
4446 (output_dir_path / self .config ["output_json_file_name" ]).write_text (
4547 cumsum_num_kernels_json
4648 )
@@ -53,9 +55,9 @@ def __init__(self, model_path: Path):
def analyze(self):
    """Collect the cumulative kernel counts into a column-oriented dict.

    Each item produced by ``self._get_cumsum_num_kernels()`` is unpacked as
    a ``(start, end, num_kernels)`` triple. The triples are split into three
    parallel lists so that index ``i`` of every list describes the same
    range.

    Returns:
        dict: ``{"num_kernels": [...], "starts": [...], "ends": [...]}``,
        in that key order. All three lists are empty when no triples are
        produced.
    """
    # Materialize once: the source may be a one-shot generator, and it is
    # traversed three times below.
    triples = list(self._get_cumsum_num_kernels())
    return {
        "num_kernels": [num_kernels for _, _, num_kernels in triples],
        "starts": [start for start, _, _ in triples],
        "ends": [end for _, end, _ in triples],
    }
6163
def _get_num_kernels_if_submodule_compiled(
    self, graph_module, nn_module, inputs, submodule_start, submodule_end
):
    """Return the kernel count recorded for one op range of ``graph_module``.

    The graph is rewritten so that the ops in
    ``[submodule_start, submodule_end)`` form a single submodule, and that
    submodule is wrapped in ``CountNumKernelsNNModule``. The rewritten graph
    is then run once on ``inputs`` and the recorded count is returned.

    NOTE(review): the wrapper presumably writes the count into the shared
    ``Optional`` when it executes -- confirm against count_kernels_util.
    NOTE(review): whether ``submodule_end`` is exclusive is assumed from the
    naming -- confirm against convert_to_submodules_graph.

    Args:
        graph_module: FX graph module to split.
        nn_module: Not used by this method.
        inputs: Positional inputs passed to the rewritten graph module.
        submodule_start: Start position of the range to extract.
        submodule_end: End position of the range to extract.

    Returns:
        The value unwrapped from the shared ``Optional`` after the run.

    Raises:
        AssertionError: If no count was recorded during the run.
    """
    torch.cuda.empty_cache()
    # Shared mutable slot: the hook hands it to every wrapper it creates.
    kernel_count_slot = Optional(None)
    op_range = (submodule_start, submodule_end)

    rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
        graph_module,
        submodule_hook=lambda m, seq_no: CountNumKernelsNNModule(
            m, kernel_count_slot
        ),
        split_positions=list(op_range),
        subgraph_ranges=[op_range],
        group_head_and_tail=False,
        chain_style=False,
    )

    # The forward result is discarded; the run exists only to fill the slot.
    rewrited_gm(*inputs)
    assert kernel_count_slot.is_some()
    return kernel_count_slot.unwrap()
92100
93101 def _get_ranges (self , gm ):
94102 num_ops = get_fx_graph_num_ops (gm )
0 commit comments