@@ -150,12 +150,14 @@ def get_len(tid):
150150 get_len (sym_id )
151151 return token2len
152152
153- def analyze (self , model_paths_file : str , device : str ) -> Dict [str , Dict ]:
153+ def analyze (
154+ self , model_path_prefix : str , model_paths_file : str , device : str
155+ ) -> Dict [str , Dict ]:
154156 input_file = Path (model_paths_file )
155157
156158 with open (input_file , "r" ) as f :
157159 model_paths = [
158- Path (line .strip () )
160+ Path (model_path_prefix ) / line .strip ()
159161 for line in f
160162 if line .strip () and not line .startswith ("#" )
161163 ]
@@ -264,7 +266,7 @@ def main(args):
264266 fold_policy = args .fold_policy ,
265267 fold_times = args .fold_times ,
266268 )
267- results = analyzer .analyze (args .model_list , args .device )
269+ results = analyzer .analyze (args .model_path_prefix , args . model_list , args .device )
268270 if args .output_json :
269271 with open (args .output_json , "w" ) as f :
270272 json .dump (results , f , indent = 4 )
@@ -280,6 +282,12 @@ def main(args):
280282 required = True ,
281283 help = "Path to a text file containing paths to models (one per line)." ,
282284 )
285+ parser .add_argument (
286+ "--model-path-prefix" ,
287+ type = str ,
288+ default = "./" ,
289+ help = "Prefix to add to each model path in the list." ,
290+ )
283291 parser .add_argument (
284292 "--device" ,
285293 type = str ,
0 commit comments