|
12 | 12 | from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel |
13 | 13 | from torch._inductor.test_case import TestCase |
14 | 14 | from torch._inductor.utils import run_and_get_code |
| 15 | +from torch.testing import assert_close |
15 | 16 | from torch.testing._internal.common_cuda import IS_SM89 |
16 | 17 | from torch.testing._internal.common_utils import ( |
17 | 18 | instantiate_parametrized_tests, |
@@ -57,12 +58,90 @@ def setUp(self): |
57 | 58 | torch._inductor.metrics.generated_kernel_count = 0 |
58 | 59 | torch._dynamo.reset() |
59 | 60 |
|
60 | | - def run_and_check(self, fn, args, *, expect_kernel_count=1): |
61 | | - args_cpu = [tensor.cpu().to(torch.float32) for tensor in args] |
62 | | - expected = fn(*args_cpu).to(torch.float16) |
63 | | - fn = torch.compile(fn, fullgraph=True) |
64 | | - result, (source_code,) = run_and_get_code(fn, *args) |
65 | | - self.assertEqual(result, expected) |
| 61 | + def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1): |
| 62 | + # Define fixed tolerances |
| 63 | + RTOL = 1e-5 |
| 64 | + ATOL = 1e-6 |
| 65 | + |
| 66 | + # calculate reference value in higher precision when input dtype is float16 |
| 67 | + ref_dtype = dtype |
| 68 | + if dtype == torch.float16: |
| 69 | + ref_dtype = torch.float64 |
| 70 | + |
| 71 | + # Cast to the determined reference dtype |
| 72 | + args_ref = [tensor.to(ref_dtype) for tensor in args] |
| 73 | + |
| 74 | + # Calculate expected output |
| 75 | + raw_expected = fn(*args_ref) |
| 76 | + |
| 77 | + if isinstance(raw_expected, (tuple, list)): |
| 78 | + # If it's a tuple or list, apply .to(dtype) to each tensor within it |
| 79 | + # Also, handle cases where dtype might not be provided (e.g., for bool reductions) |
| 80 | + if dtype is not None: |
| 81 | + expected = type(raw_expected)( |
| 82 | + [ |
| 83 | + t.to(dtype) if isinstance(t, torch.Tensor) else t |
| 84 | + for t in raw_expected |
| 85 | + ] |
| 86 | + ) |
| 87 | + else: |
| 88 | + expected = type(raw_expected)( |
| 89 | + [ |
| 90 | + t.to(torch.float64) if isinstance(t, torch.Tensor) else t |
| 91 | + for t in raw_expected |
| 92 | + ] |
| 93 | + ) |
| 94 | + else: |
| 95 | + # If it's a single tensor |
| 96 | + if dtype is not None: |
| 97 | + expected = raw_expected.to(dtype) |
| 98 | + else: |
| 99 | + expected = raw_expected.to(torch.float64) |
| 100 | + |
| 101 | + fn_compiled = torch.compile(fn, fullgraph=True) |
| 102 | + result, (source_code,) = run_and_get_code(fn_compiled, *args) |
| 103 | + |
| 104 | + # For comparison, ensure result is also a tuple/list if expected is |
| 105 | + if isinstance(expected, (tuple, list)): |
| 106 | + if isinstance(result, torch.Tensor): |
| 107 | + result = (result,) |
| 108 | + elif not isinstance(result, type(expected)): |
| 109 | + result = type(expected)(result) |
| 110 | + |
| 111 | + if dtype is not None: |
| 112 | + result = type(result)( |
| 113 | + [t.to(dtype) if isinstance(t, torch.Tensor) else t for t in result] |
| 114 | + ) |
| 115 | + else: |
| 116 | + result = type(result)( |
| 117 | + [ |
| 118 | + t.to(torch.float64) if isinstance(t, torch.Tensor) else t |
| 119 | + for t in result |
| 120 | + ] |
| 121 | + ) |
| 122 | + else: |
| 123 | + if dtype is not None and isinstance(result, torch.Tensor): |
| 124 | + result = result.to(dtype) |
| 125 | + elif isinstance(result, torch.Tensor): |
| 126 | + result = result.to(torch.float64) |
| 127 | + |
| 128 | + # Apply assert_close with fixed tolerances for tensor comparisons |
| 129 | + if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor): |
| 130 | + assert_close(result, expected, rtol=RTOL, atol=ATOL) |
| 131 | + elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)): |
| 132 | + # Iterate through elements for comparison |
| 133 | + for r_item, e_item in zip(result, expected): |
| 134 | + if isinstance(r_item, torch.Tensor) and isinstance( |
| 135 | + e_item, torch.Tensor |
| 136 | + ): |
| 137 | + assert_close(r_item, e_item, rtol=RTOL, atol=ATOL) |
| 138 | + else: |
| 139 | + # Fallback to assertEqual for non-tensor elements (e.g., bool, int) |
| 140 | + self.assertEqual(r_item, e_item) |
| 141 | + else: |
| 142 | + # Fallback to assertEqual for other types not handled by assert_close |
| 143 | + self.assertEqual(result, expected) |
| 144 | + |
66 | 145 | if "@triton_heuristics.fixed_config" in source_code: |
67 | 146 | self.assertIn("cooperative_reduction_grid", source_code) |
68 | 147 | else: |
@@ -98,7 +177,7 @@ def fn(x, y): |
98 | 177 |
|
99 | 178 | reduction_fn = getattr(torch, name) |
100 | 179 | args = [torch.randn(1, 1024**2, device="cuda", dtype=dtype) for _ in range(2)] |
101 | | - self.run_and_check(fn, args) |
| 180 | + self.run_and_check(fn, args, dtype) |
102 | 181 |
|
103 | 182 | def test_bool_reduction_fns(self): |
104 | 183 | def fn(x, y): |
|
0 commit comments