@@ -24,6 +24,7 @@ def declare_config(
2424 model_path_prefix : str ,
2525 output_dir : str ,
2626 resume : bool = False ,
27+ device : str = "auto" ,
2728 start_offset_in_original_graph : int = 0 ,
2829 limits_handled_models : int = None ,
2930 output_json_file_name : str = "cumsum_num_kernels.json" ,
@@ -39,19 +40,30 @@ def sample_handled(self, rel_model_path: str) -> bool:
3940
4041 def resume (self , rel_model_path : str ):
4142 model_path = Path (self .config ["model_path_prefix" ]) / rel_model_path
43+ device = self ._choose_device (self .config ["device" ])
4244 start_offset_in_original_graph = self .config ["start_offset_in_original_graph" ]
43- analyzer = CumsumNumKernelsAnalyzer (model_path , start_offset_in_original_graph )
45+ analyzer = CumsumNumKernelsAnalyzer (
46+ model_path , device , start_offset_in_original_graph
47+ )
4448 cumsum_num_kernels = analyzer .analyze ()
4549 cumsum_num_kernels_json = json .dumps (cumsum_num_kernels , indent = 4 )
4650 output_dir_path = Path (self .config ["output_dir" ]) / rel_model_path
4751 output_dir_path .mkdir (parents = True , exist_ok = True )
4852 output_file_path = output_dir_path / self .config ["output_json_file_name" ]
4953 output_file_path .write_text (cumsum_num_kernels_json )
5054
55+ def _choose_device (self , device ) -> str :
56+ if device in ["cpu" , "cuda" ]:
57+ return device
58+ return "cuda" if torch .cuda .is_available () else "cpu"
59+
5160
5261class CumsumNumKernelsAnalyzer :
53- def __init__ (self , model_path : Path , start_offset_in_original_graph : int ):
62+ def __init__ (
63+ self , model_path : Path , device : str , start_offset_in_original_graph : int
64+ ):
5465 self .model_path = model_path
66+ self .device = device
5567 self .start_offset_in_original_graph = start_offset_in_original_graph
5668
5769 def analyze (self ):
@@ -67,8 +79,12 @@ def analyze(self):
6779
6880 def _get_cumsum_num_kernels (self ):
6981 model_path = str (self .model_path )
70- module , inputs = get_torch_module_and_inputs (model_path , use_dummy_inputs = False )
71- gm = parse_immutable_model_path_into_sole_graph_module (model_path )
82+ module , inputs = get_torch_module_and_inputs (
83+ model_path , use_dummy_inputs = False , device = self .device
84+ )
85+ gm = parse_immutable_model_path_into_sole_graph_module (
86+ model_path , device = self .device
87+ )
7288 for start , end in self ._get_ranges (gm ):
7389 assert start == 0
7490 num_kernels = self ._get_num_kernels_if_submodule_compiled (
@@ -97,6 +113,7 @@ def compile_and_count_num_kernels(m, seq_no):
97113 group_head_and_tail = False ,
98114 chain_style = False ,
99115 )
116+
100117 rewrited_gm (* inputs )
101118 assert mut_opt_num_kernels .is_some ()
102119 return mut_opt_num_kernels .unwrap ()
0 commit comments