1010if deps .jax_enabled :
1111 import jax
1212
13- jaxarray_type = jax .Array
13+ jaxarrayin_type = jax .typing .ArrayLike
14+ jaxarrayout_type = jax .Array
1415else :
1516 jax_message = (
1617 "JAX package not installed. In order to be able to use"
1718 'the jaxoperator module run "pip install jax" or'
1819 '"conda install -c conda-forge jax".'
1920 )
20- jaxarray_type = Any
21+ jaxarrayin_type = None
22+ jaxarrayout_type = None
2123
22- JaxType = NewType ("JaxType" , jaxarray_type )
24+ JaxTypeIn = NewType ("JaxTypeIn" , jaxarrayin_type )
25+ JaxTypeOut = NewType ("JaxTypeOut" , jaxarrayout_type )
2326
2427
2528class JaxOperator (LinearOperator ):
@@ -57,12 +60,12 @@ def __init__(self, Op: LinearOperator) -> None:
5760 def __call__ (self , x , * args , ** kwargs ):
5861 return self ._matvec (x )
5962
60- def _rmatvecad (self , x : JaxType , y : JaxType ) -> JaxType :
63+ def _rmatvecad (self , x : JaxTypeIn , y : JaxTypeIn ) -> JaxTypeOut :
6164 _ , f_vjp = jax .vjp (self ._matvec , x )
6265 xadj = jax .jit (f_vjp )(y )[0 ]
6366 return xadj
6467
65- def rmatvecad (self , x : JaxType , y : JaxType ) -> JaxType :
68+ def rmatvecad (self , x : JaxTypeIn , y : JaxTypeIn ) -> JaxTypeOut :
6669 """Vector-Jacobian product
6770
6871 JIT-compiled Vector-Jacobian product
@@ -76,7 +79,7 @@ def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
7679
7780 Returns
7881 -------
79- xadj : :obj:`jax.Array `
82+ xadj : :obj:`jax.typing.ArrayLike `
8083 Output array
8184
8285 """
0 commit comments