Skip to content

Commit 504c738

Browse files
committed
Use next to tiny as smallest floating point value on Mac ARM
1 parent dc11d40 commit 504c738

File tree

2 files changed

+9
-24
lines changed

2 files changed

+9
-24
lines changed

jax/_src/test_util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import logging
2626
import math
2727
import os
28+
import platform
2829
import re
2930
import sys
3031
import tempfile
@@ -1704,6 +1705,10 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
17041705
size_im = size_re
17051706
finfo = np.finfo(dtype)
17061707

1708+
machine = platform.machine()
1709+
is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm')
1710+
smallest = np.nextafter(finfo.tiny, finfo.max) if is_arm_cpu and platform.system() == 'Darwin' else finfo.tiny
1711+
17071712
def make_axis_points(size):
17081713
prec_dps_ratio = 3.3219280948873626
17091714
logmin = logmax = finfo.maxexp / prec_dps_ratio
@@ -1722,8 +1727,8 @@ def make_axis_points(size):
17221727
axis_points[1] = finfo.min
17231728
axis_points[-2] = finfo.max
17241729
if size > 0:
1725-
axis_points[size] = -finfo.tiny
1726-
axis_points[-size - 1] = finfo.tiny
1730+
axis_points[size] = -smallest
1731+
axis_points[-size - 1] = smallest
17271732
axis_points[0] = -np.inf
17281733
axis_points[-1] = np.inf
17291734
return axis_points

tests/lax_test.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4400,34 +4400,14 @@ def regions_with_inaccuracies_keep(*to_keep):
44004400
elif name == 'tanh':
44014401
regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
44024402

4403-
elif name == 'arcsin':
4404-
if is_arm_cpu and platform.system() == 'Darwin':
4405-
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real')
4406-
else:
4407-
regions_with_inaccuracies.clear()
4408-
4409-
elif name == 'arcsinh':
4410-
if is_arm_cpu and platform.system() == 'Darwin':
4411-
regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag',
4412-
'negj.imag', 'posj.imag')
4413-
else:
4414-
regions_with_inaccuracies.clear()
4415-
44164403
elif name == 'arccos':
44174404
regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real')
44184405

44194406
elif name in {'cos', 'sin'}:
44204407
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
44214408

4422-
elif name == 'log1p':
4423-
if is_arm_cpu and platform.system() == 'Darwin':
4424-
regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag',
4425-
'posj.imag')
4426-
else:
4427-
regions_with_inaccuracies.clear()
4428-
4429-
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan',
4430-
'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}:
4409+
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan', 'log1p',
4410+
'arcsin', 'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}:
44314411
regions_with_inaccuracies.clear()
44324412
else:
44334413
assert 0 # unreachable

0 commit comments

Comments
 (0)