@@ -30,6 +30,7 @@ def declare_config(
         subgraph_ranges_json_key: str = "subgraph_ranges",
         group_head_and_tail: bool = False,
         chain_style: bool = False,
+        device: str = "auto",
         resume: bool = False,
         limits_handled_models: int = None,
     ):
@@ -63,7 +64,10 @@ def _has_enough_subgraphs(self, rel_model_path, num_subgraphs):
     def resume(self, rel_model_path: str):
         model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
         torch.cuda.empty_cache()
-        module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
+        device = self._choose_device(self.config["device"])
+        module, inputs = get_torch_module_and_inputs(
+            model_path, use_dummy_inputs=False, device=device
+        )
         gm = parse_sole_graph_module(module, inputs)
         torch.cuda.empty_cache()
         subgraph_ranges = self._get_subgraph_ranges(rel_model_path)
@@ -105,6 +109,11 @@ def fn(submodule, seq_no):
 
         return fn
 
+    def _choose_device(self, device) -> str:
+        if device in ["cpu", "cuda"]:
+            return device
+        return "cuda" if torch.cuda.is_available() else "cpu"
+
 
 class NaiveDecomposerExtractorModule(torch.nn.Module):
     def __init__(
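For context, the fallback behaviour introduced by the new `device` option can be exercised in isolation. A minimal sketch, assuming only that `torch` is importable; the standalone `choose_device` function below is illustrative and not part of the module's API:

import torch

def choose_device(device: str) -> str:
    # Mirror of the new _choose_device helper: explicit "cpu"/"cuda" values
    # are passed through; any other value (e.g. the default "auto") resolves
    # to CUDA when available and falls back to CPU otherwise.
    if device in ["cpu", "cuda"]:
        return device
    return "cuda" if torch.cuda.is_available() else "cpu"

print(choose_device("auto"))  # "cuda" on a GPU host, "cpu" otherwise
print(choose_device("cpu"))   # explicit values are respected unchanged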