Skip to content

Commit 310b2a9

Browse files
committed
minor: one more fix of typing of jax
1 parent 29a4dc2 commit 310b2a9

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pylops/jaxoperator.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,19 @@
1010
if deps.jax_enabled:
1111
import jax
1212

13-
jaxarray_type = jax.Array
13+
jaxarrayin_type = jax.typing.ArrayLike
14+
jaxarrayout_type = jax.Array
1415
else:
1516
jax_message = (
1617
"JAX package not installed. In order to be able to use"
1718
'the jaxoperator module run "pip install jax" or'
1819
'"conda install -c conda-forge jax".'
1920
)
20-
jaxarray_type = Any
21+
jaxarrayin_type = None
22+
jaxarrayout_type = None
2123

22-
JaxType = NewType("JaxType", jaxarray_type)
24+
JaxTypeIn = NewType("JaxTypeIn", jaxarrayin_type)
25+
JaxTypeOut = NewType("JaxTypeOut", jaxarrayout_type)
2326

2427

2528
class JaxOperator(LinearOperator):
@@ -57,12 +60,12 @@ def __init__(self, Op: LinearOperator) -> None:
5760
def __call__(self, x, *args, **kwargs):
5861
return self._matvec(x)
5962

60-
def _rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
63+
def _rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
6164
_, f_vjp = jax.vjp(self._matvec, x)
6265
xadj = jax.jit(f_vjp)(y)[0]
6366
return xadj
6467

65-
def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
68+
def rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
6669
"""Vector-Jacobian product
6770
6871
JIT-compiled Vector-Jacobian product
@@ -76,7 +79,7 @@ def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
7679
7780
Returns
7881
-------
79-
xadj : :obj:`jax.Array`
82+
xadj : :obj:`jax.typing.ArrayLike`
8083
Output array
8184
8285
"""

0 commit comments

Comments
 (0)