@@ -348,23 +348,27 @@ def get_cmp_all_close(expected_out, compiled_out, atol, rtol):
348348
349349def get_cmp_max_diff (expected_out , compiled_out ):
350350 return " " .join (
351- str (torch .max (torch .abs (a - b )).item ())
351+ str (torch .max (torch .abs (a . float () - b . float () )).item ())
352352 for a , b in zip (expected_out , compiled_out )
353353 )
354354
355355
356356def get_cmp_mean_diff (expected_out , compiled_out ):
357357 return " " .join (
358- str (torch .mean (torch .abs (a - b )).item ())
358+ str (torch .mean (torch .abs (a . float () - b . float () )).item ())
359359 for a , b in zip (expected_out , compiled_out )
360360 )
361361
362362
363363def get_cmp_diff_count (expected_out , compiled_out , atol , rtol ):
364- return " " .join (
365- str (torch .sum (~ torch .isclose (a , b , atol = atol , rtol = rtol )).item ())
366- for a , b in zip (expected_out , compiled_out )
367- )
364+ results = []
365+ for a , b in zip (expected_out , compiled_out ):
366+ if a .is_floating_point () and b .is_floating_point ():
367+ diff_count = torch .sum (~ torch .isclose (a , b , atol = atol , rtol = rtol )).item ()
368+ else :
369+ diff_count = torch .sum (a != b ).item ()
370+ results .append (str (diff_count ))
371+ return " " .join (results )
368372
369373
370374def test_multi_models (args ):
0 commit comments