Skip to content

Commit 59925f5

Browse files
[release/2.7] Default unit tests fix (#2545)
Fixes inductor.test_cooperative_reductions.py Fixes test_quantization.py - Python integer 128 out of bounds for int8 --------- Co-authored-by: Philip Maybank <[email protected]>
1 parent cbf75ac commit 59925f5

File tree

2 files changed

+87
-8
lines changed

2 files changed

+87
-8
lines changed

test/inductor/test_cooperative_reductions.py

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
1313
from torch._inductor.test_case import TestCase
1414
from torch._inductor.utils import run_and_get_code
15+
from torch.testing import assert_close
1516
from torch.testing._internal.common_cuda import IS_SM89
1617
from torch.testing._internal.common_utils import (
1718
instantiate_parametrized_tests,
@@ -57,12 +58,90 @@ def setUp(self):
5758
torch._inductor.metrics.generated_kernel_count = 0
5859
torch._dynamo.reset()
5960

60-
def run_and_check(self, fn, args, *, expect_kernel_count=1):
61-
args_cpu = [tensor.cpu().to(torch.float32) for tensor in args]
62-
expected = fn(*args_cpu).to(torch.float16)
63-
fn = torch.compile(fn, fullgraph=True)
64-
result, (source_code,) = run_and_get_code(fn, *args)
65-
self.assertEqual(result, expected)
61+
def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1):
62+
# Define fixed tolerances
63+
RTOL = 1e-5
64+
ATOL = 1e-6
65+
66+
# calculate reference value in higher precision when input dtype is float16
67+
ref_dtype = dtype
68+
if dtype == torch.float16:
69+
ref_dtype = torch.float64
70+
71+
# Cast to the determined reference dtype
72+
args_ref = [tensor.to(ref_dtype) for tensor in args]
73+
74+
# Calculate expected output
75+
raw_expected = fn(*args_ref)
76+
77+
if isinstance(raw_expected, (tuple, list)):
78+
# If it's a tuple or list, apply .to(dtype) to each tensor within it
79+
# Also, handle cases where dtype might not be provided (e.g., for bool reductions)
80+
if dtype is not None:
81+
expected = type(raw_expected)(
82+
[
83+
t.to(dtype) if isinstance(t, torch.Tensor) else t
84+
for t in raw_expected
85+
]
86+
)
87+
else:
88+
expected = type(raw_expected)(
89+
[
90+
t.to(torch.float64) if isinstance(t, torch.Tensor) else t
91+
for t in raw_expected
92+
]
93+
)
94+
else:
95+
# If it's a single tensor
96+
if dtype is not None:
97+
expected = raw_expected.to(dtype)
98+
else:
99+
expected = raw_expected.to(torch.float64)
100+
101+
fn_compiled = torch.compile(fn, fullgraph=True)
102+
result, (source_code,) = run_and_get_code(fn_compiled, *args)
103+
104+
# For comparison, ensure result is also a tuple/list if expected is
105+
if isinstance(expected, (tuple, list)):
106+
if isinstance(result, torch.Tensor):
107+
result = (result,)
108+
elif not isinstance(result, type(expected)):
109+
result = type(expected)(result)
110+
111+
if dtype is not None:
112+
result = type(result)(
113+
[t.to(dtype) if isinstance(t, torch.Tensor) else t for t in result]
114+
)
115+
else:
116+
result = type(result)(
117+
[
118+
t.to(torch.float64) if isinstance(t, torch.Tensor) else t
119+
for t in result
120+
]
121+
)
122+
else:
123+
if dtype is not None and isinstance(result, torch.Tensor):
124+
result = result.to(dtype)
125+
elif isinstance(result, torch.Tensor):
126+
result = result.to(torch.float64)
127+
128+
# Apply assert_close with fixed tolerances for tensor comparisons
129+
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
130+
assert_close(result, expected, rtol=RTOL, atol=ATOL)
131+
elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)):
132+
# Iterate through elements for comparison
133+
for r_item, e_item in zip(result, expected):
134+
if isinstance(r_item, torch.Tensor) and isinstance(
135+
e_item, torch.Tensor
136+
):
137+
assert_close(r_item, e_item, rtol=RTOL, atol=ATOL)
138+
else:
139+
# Fallback to assertEqual for non-tensor elements (e.g., bool, int)
140+
self.assertEqual(r_item, e_item)
141+
else:
142+
# Fallback to assertEqual for other types not handled by assert_close
143+
self.assertEqual(result, expected)
144+
66145
if "@triton_heuristics.fixed_config" in source_code:
67146
self.assertIn("cooperative_reduction_grid", source_code)
68147
else:
@@ -98,7 +177,7 @@ def fn(x, y):
98177

99178
reduction_fn = getattr(torch, name)
100179
args = [torch.randn(1, 1024**2, device="cuda", dtype=dtype) for _ in range(2)]
101-
self.run_and_check(fn, args)
180+
self.run_and_check(fn, args, dtype)
102181

103182
def test_bool_reduction_fns(self):
104183
def fn(x, y):

test/quantization/core/test_quantized_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class PointwisePostOp(NamedTuple):
6666
def avoid_vpmaddubsw_overflow_linear(
6767
batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
6868
):
69-
if sys.version_info >= (3, 13):
69+
if np.lib.NumpyVersion(np.__version__) >= '2.1.0':
7070
raise unittest.SkipTest("numpy 2.1 overflow error")
7171
for i, j in np.ndindex((batch_size, output_channels)):
7272
for k in range(0, input_channels // 2 * 2, 2):

0 commit comments

Comments
 (0)