Skip to content

Commit 464e5a2

Browse files
Merge pull request #25569 from hawkinsp:numpyver
PiperOrigin-RevId: 707570246
2 parents 3f24dfd + ee45718 commit 464e5a2

File tree

4 files changed

+6
-8
lines changed

4 files changed

+6
-8
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1212

1313
## Unreleased
1414

15+
* Changes:
16+
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
17+
supported version until June 2025.
18+
1519
## jax 0.4.38 (Dec 17, 2024)
1620

1721
* Changes:

jaxlib/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def has_ext_modules(self):
6363
install_requires=[
6464
'scipy>=1.10',
6565
"scipy>=1.11.1; python_version>='3.12'",
66-
'numpy>=1.24',
66+
'numpy>=1.25',
6767
'ml_dtypes>=0.2.0',
6868
],
6969
url='https://github.com/jax-ml/jax',

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def load_version_module(pkg_path):
5757
install_requires=[
5858
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
5959
'ml_dtypes>=0.4.0',
60-
'numpy>=1.24',
60+
'numpy>=1.25',
6161
"numpy>=1.26.0; python_version>='3.12'",
6262
'opt_einsum',
6363
'scipy>=1.10',

tests/array_interoperability_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525

2626
import numpy as np
2727

28-
numpy_version = jtu.numpy_version()
29-
3028
config.parse_flags_with_absl()
3129

3230
try:
@@ -48,10 +46,6 @@
4846
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
4947
key=lambda x: x.__name__)
5048

51-
# NumPy didn't support bool as a dlpack type until 1.25.
52-
if jtu.numpy_version() < (1, 25, 0):
53-
numpy_dtypes = [dt for dt in numpy_dtypes if dt != jnp.bool_]
54-
5549
cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]
5650

5751
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]

0 commit comments

Comments
 (0)