How does deprecation of jax _src/scipy functions work?
#18995
In #17182 @hawkinsp deprecated …

$ docker run --rm -ti python:3.11 /bin/bash
root@776f0a340338:/# python -m venv venv && . venv/bin/activate
(venv) root@776f0a340338:/# python -m pip --quiet install --upgrade pip wheel
(venv) root@776f0a340338:/# python -m pip --quiet install --upgrade scipy jax jaxlib
(venv) root@776f0a340338:/# python -m pip list | grep 'scipy\|jax'
jax 0.4.23
jaxlib 0.4.23
scipy 1.11.4
(venv) root@776f0a340338:/# python -c 'from scipy.linalg import tril; import jax.scipy.linalg' # All good
(venv) root@776f0a340338:/# python -m pip install --upgrade --pre --index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple --extra-index-url https://pypi.org/simple scipy
Looking in indexes: https://pypi.anaconda.org/scientific-python-nightly-wheels/simple, https://pypi.org/simple
Requirement already satisfied: scipy in /venv/lib/python3.11/site-packages (1.11.4)
Collecting scipy
  Downloading https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/scipy/1.12.0.dev0/scipy-1.12.0.dev0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 38.1/38.1 MB 14.5 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.22.4 in /venv/lib/python3.11/site-packages (from scipy) (1.26.2)
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.11.4
    Uninstalling scipy-1.11.4:
      Successfully uninstalled scipy-1.11.4
Successfully installed scipy-1.12.0.dev0
(venv) root@776f0a340338:/# python -c 'from scipy.linalg import tril' # tril removed in the scipy nightly
Traceback (most recent call last):
  File "<string>", line 1, in <module>
ImportError: cannot import name 'tril' from 'scipy.linalg' (/venv/lib/python3.11/site-packages/scipy/linalg/__init__.py)
(venv) root@776f0a340338:/# python -c 'import jax.scipy.linalg' # which now breaks jax
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/venv/lib/python3.11/site-packages/jax/scipy/linalg.py", line 18, in <module>
    from jax._src.scipy.linalg import (
  File "/venv/lib/python3.11/site-packages/jax/_src/scipy/linalg.py", line 403, in <module>
    @_wraps(scipy.linalg.tril)
            ^^^^^^^^^^^^^^^^^
AttributeError: module 'scipy.linalg' has no attribute 'tril'
(venv) root@776f0a340338:/#

I haven't taken the time to step through how …
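The last frame of the traceback points at `@_wraps(scipy.linalg.tril)`. A decorator expression is evaluated when the `def` statement executes, which means it runs when the module is imported. As a result, a wrapper module breaks as soon as upstream drops the wrapped function, whether or not anyone ever calls it. The following is a minimal, self-contained sketch of that behaviour. It rests on a few assumptions. The `_wraps` defined here is a simplified, hypothetical stand-in that only copies a docstring, and JAX's real `_wraps` is not reproduced. `make_linalg` is also hypothetical: it builds a fake module in place of `scipy.linalg`.

```python
import types


def _wraps(fun):
    """Hypothetical, simplified decorator factory: copies `fun`'s docstring
    onto the decorated function. Not JAX's real `_wraps`."""
    def decorator(op):
        op.__doc__ = fun.__doc__
        return op
    return decorator


def make_linalg(with_tril):
    """Builds a hypothetical stand-in for scipy.linalg, with or without `tril`."""
    mod = types.ModuleType("linalg")
    if with_tril:
        def tril(m, k=0):
            """Lower triangle of an array."""
            return m
        mod.tril = tril
    return mod


def define_tril(linalg):
    """Mimics the body of a wrapper module such as jax/_src/scipy/linalg.py.

    The decorator argument `linalg.tril` is looked up when this `def`
    statement runs (module import time in the real file), not when the
    wrapped function is first called.
    """
    @_wraps(linalg.tril)
    def tril(m, k=0):
        return m  # body is irrelevant here; it is never called
    return tril


# Works while upstream still provides `tril` ...
define_tril(make_linalg(with_tril=True))

# ... and fails once upstream removes it, even though nothing called tril.
try:
    define_tril(make_linalg(with_tril=False))
except AttributeError as err:
    print("AttributeError at definition time:", err)
```

This is why a deprecation on the JAX side cannot shield users here. The failing lookup is of the upstream SciPy name, and it happens before any JAX-level deprecation logic gets a chance to run.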
Hi - our API compatibility policy specifies that deprecated functions will raise deprecation warnings for at least 3 months before they are removed. #17182 was part of the 0.4.16 release on Sept 18, so the earliest we would remove them is Dec 18, which is Monday. I'm currently working on a branch that removes these and a number of other deprecations, and I hope to send it for review on Monday.

Regarding nightly CI failures: until recently, we had a CI job that ran tests against upstream NumPy and SciPy. That job began failing when the NumPy 2.0 ABI changed. It's currently not possible to build JAX against NumPy 2.0 because pybind11 is not yet compatible with it; as soon as that is worked out, we'll reinstate the NumPy/SciPy nightly builds and catch this kind of incompatibility earlier.
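The earliest removal date in the reply is calendar arithmetic: the 0.4.16 release date plus the policy's three-month minimum. The following small sketch makes that explicit. It assumes a release year of 2023, which is inferred from the versions in the question (jax 0.4.23) and not stated in the reply. `add_months` is a hypothetical helper.

```python
import datetime


def add_months(d, months):
    """Hypothetical helper: same day-of-month `months` later.

    Raises ValueError if that day does not exist in the target month
    (e.g. Jan 31 + 1 month); not an issue for the 18th.
    """
    total = d.month - 1 + months
    return d.replace(year=d.year + total // 12, month=total % 12 + 1)


release = datetime.date(2023, 9, 18)  # 0.4.16 release; year 2023 is assumed
earliest_removal = add_months(release, 3)
print(earliest_removal, earliest_removal.strftime("%A"))
```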