@@ -199,12 +199,12 @@ def get_timing_stats(elapsed_times: list[float]):
199199
200200
201201def measure_performance (model_call , args , compiler ):
202- if args . device == "cuda" :
202+ if "cuda" in args . device :
203203 times = time_execution_with_cuda_event (
204204 model_call ,
205205 num_warmup = args .warmup ,
206206 num_trials = args .trials ,
207- device = torch .device ("cuda:0" ),
207+ device = torch .device (args . device ),
208208 )
209209 else :
210210 times = time_execution_naive (
@@ -243,8 +243,10 @@ def test_single_model(args):
243243 },
244244 }
245245
246- if args .device == "cuda" :
247- result_data ["configuration" ]["hardware" ] = torch .cuda .get_device_name (0 )
246+ if "cuda" in args .device :
247+ result_data ["configuration" ]["hardware" ] = torch .cuda .get_device_name (
248+ args .device
249+ )
248250 elif args .device == "cpu" :
249251 result_data ["configuration" ]["hardware" ] = platform .processor ()
250252 else :
@@ -335,38 +337,42 @@ def print_and_store_cmp(key, func, **kwargs):
335337
def get_cmp_equal(expected_out, compiled_out):
    """Return a space-separated string of per-output exact-equality flags.

    Each paired tensor from *expected_out* / *compiled_out* contributes one
    "1" (exactly equal) or "0". Tensors are moved to CPU first so that
    device placement never affects the comparison.
    """
    flags = []
    for exp, got in zip(expected_out, compiled_out):
        same = torch.equal(exp.cpu(), got.cpu())
        flags.append(str(int(same)))
    return " ".join(flags)
340343
341344
def get_cmp_all_close(expected_out, compiled_out, atol, rtol):
    """Return a space-separated string of per-output allclose flags.

    Each paired tensor contributes "1" when ``torch.allclose`` holds at the
    given *atol*/*rtol* tolerances, else "0". Comparison is done on CPU so
    device placement does not matter.
    """
    flags = [
        str(int(torch.allclose(exp.cpu(), got.cpu(), atol=atol, rtol=rtol)))
        for exp, got in zip(expected_out, compiled_out)
    ]
    return " ".join(flags)
347350
348351
def get_cmp_max_diff(expected_out, compiled_out):
    """Return a space-separated string of per-output max absolute differences.

    Both tensors of each pair are moved to CPU and cast to float32 before
    subtracting, so mixed dtypes and device placement do not matter.
    """
    maxima = []
    for exp, got in zip(expected_out, compiled_out):
        delta = (exp.cpu().float() - got.cpu().float()).abs()
        maxima.append(str(delta.max().item()))
    return " ".join(maxima)
354357
355358
def get_cmp_mean_diff(expected_out, compiled_out):
    """Return a space-separated string of per-output mean absolute differences.

    Both tensors of each pair are moved to CPU and cast to float32 before
    subtracting, so mixed dtypes and device placement do not matter.
    """
    means = []
    for exp, got in zip(expected_out, compiled_out):
        delta = (exp.cpu().float() - got.cpu().float()).abs()
        means.append(str(delta.mean().item()))
    return " ".join(means)
361364
362365
def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
    """Return a space-separated string of per-output mismatching-element counts.

    For pairs where both tensors are floating point, an element counts as a
    mismatch when ``torch.isclose`` fails at the given *atol*/*rtol*; for any
    other dtype combination, plain ``!=`` is used. Tensors are compared on
    CPU so device placement does not matter.
    """
    counts = []
    for exp, got in zip(expected_out, compiled_out):
        exp_cpu = exp.cpu()
        got_cpu = got.cpu()
        if exp_cpu.is_floating_point() and got_cpu.is_floating_point():
            mismatch = ~torch.isclose(exp_cpu, got_cpu, atol=atol, rtol=rtol)
            n_bad = mismatch.sum().item()
        else:
            n_bad = (exp_cpu != got_cpu).sum().item()
        counts.append(str(n_bad))
    return " ".join(counts)
372378
0 commit comments