|
8 | 8 | from typing import Optional |
9 | 9 | from typing import Sequence |
10 | 10 |
|
11 | | -import jax |
12 | 11 | import numpy as np |
13 | 12 | from scipy import sparse |
14 | 13 | from sklearn.base import TransformerMixin |
@@ -145,31 +144,32 @@ def x_sequence_or_item(wrapped_func): |
145 | 144 | @wraps(wrapped_func) |
146 | 145 | def func(self, x, *args, **kwargs): |
147 | 146 | if isinstance(x, Sequence): |
148 | | - if isinstance(x[0], jax.Array): |
149 | | - xs = x |
150 | | - else: |
| 147 | + if isinstance(x[0], np.ndarray): |
151 | 148 | xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] |
| 149 | + else: |
| 150 | + # e.g. jax array |
| 151 | + xs = x |
152 | 152 | result = wrapped_func(self, xs, *args, **kwargs) |
153 | 153 | # if transform() is a normal "return x" |
154 | 154 | if isinstance(result, Sequence) and isinstance(result[0], np.ndarray): |
155 | 155 | return [AxesArray(xp, comprehend_axes(xp)) for xp in result] |
156 | 156 | return result # e.g. fit() returns self |
157 | 157 | else: |
158 | | - if isinstance(x, jax.Array): |
159 | | - |
160 | | - def reconstructor(x): |
161 | | - return x |
162 | | - |
163 | | - elif not sparse.issparse(x) and isinstance(x, np.ndarray): |
| 158 | + if not sparse.issparse(x) and isinstance(x, np.ndarray): |
164 | 159 | x = AxesArray(x, comprehend_axes(x)) |
165 | 160 |
|
166 | 161 | def reconstructor(x): |
167 | 162 | return x |
168 | 163 |
|
169 | | - else: # sparse |
| 164 | + elif sparse.issparse(x): |
170 | 165 | reconstructor = type(x) |
171 | 166 | axes = comprehend_axes(x) |
172 | 167 | wrap_axes(axes, x) |
| 168 | + else: # e.g. jax array |
| 169 | + |
| 170 | + def reconstructor(x): |
| 171 | + return x |
| 172 | + |
173 | 173 | result = wrapped_func(self, [x], *args, **kwargs) |
174 | 174 | if isinstance(result, Sequence): # e.g. transform() returns x |
175 | 175 | return reconstructor(result[0]) |
|
0 commit comments