Skip to content

Commit fcc063a

Browse files
BLD: Jax is used in SBR optional, not main (#622)
1 parent 069afc9 commit fcc063a

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ classifiers = [
2626
]
2727
readme = "README.rst"
2828
dependencies = [
29-
"jax>=0.4,<0.5",
3029
"scikit-learn>=1.1, !=1.5.0, !=1.6.0",
3130
"numpy<2.0",
3231
"derivative>=0.6.2",
@@ -69,7 +68,8 @@ cvxpy = [
6968
sbr = [
7069
"numpyro",
7170
"arviz==0.17.1",
72-
"scipy<1.13.0"
71+
"scipy<1.13.0",
72+
"jax>=0.4,<0.5"
7373
]
7474

7575
[tool.black]

pysindy/feature_library/base.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Optional
99
from typing import Sequence
1010

11-
import jax
1211
import numpy as np
1312
from scipy import sparse
1413
from sklearn.base import TransformerMixin
@@ -145,31 +144,32 @@ def x_sequence_or_item(wrapped_func):
145144
@wraps(wrapped_func)
146145
def func(self, x, *args, **kwargs):
147146
if isinstance(x, Sequence):
148-
if isinstance(x[0], jax.Array):
149-
xs = x
150-
else:
147+
if isinstance(x[0], np.ndarray):
151148
xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x]
149+
else:
150+
# e.g. jax array
151+
xs = x
152152
result = wrapped_func(self, xs, *args, **kwargs)
153153
# if transform() is a normal "return x"
154154
if isinstance(result, Sequence) and isinstance(result[0], np.ndarray):
155155
return [AxesArray(xp, comprehend_axes(xp)) for xp in result]
156156
return result # e.g. fit() returns self
157157
else:
158-
if isinstance(x, jax.Array):
159-
160-
def reconstructor(x):
161-
return x
162-
163-
elif not sparse.issparse(x) and isinstance(x, np.ndarray):
158+
if not sparse.issparse(x) and isinstance(x, np.ndarray):
164159
x = AxesArray(x, comprehend_axes(x))
165160

166161
def reconstructor(x):
167162
return x
168163

169-
else: # sparse
164+
elif sparse.issparse(x):
170165
reconstructor = type(x)
171166
axes = comprehend_axes(x)
172167
wrap_axes(axes, x)
168+
else: # e.g. jax array
169+
170+
def reconstructor(x):
171+
return x
172+
173173
result = wrapped_func(self, [x], *args, **kwargs)
174174
if isinstance(result, Sequence): # e.g. transform() returns x
175175
return reconstructor(result[0])

0 commit comments

Comments
 (0)