Skip to content

Commit 5a8a20b

Browse files
committed
minor: fixed typing of jax
1 parent 6af08b6 commit 5a8a20b

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

pylops/jaxoperator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99

1010
if deps.jax_enabled:
1111
import jax
12-
import jaxlib
1312

14-
jaxarray_type = jaxlib.xla_extension.ArrayImpl
13+
jaxarray_type = jax.Array
1514
else:
1615
jax_message = (
1716
"JAX package not installed. In order to be able to use"
@@ -70,14 +69,14 @@ def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
7069
7170
Parameters
7271
----------
73-
x : :obj:`jaxlib.xla_extension.ArrayImpl`
72+
x : :obj:`jax.Array`
7473
Input array for forward
75-
y : :obj:`jaxlib.xla_extension.ArrayImpl`
74+
y : :obj:`jax.Array`
7675
Input array for adjoint
7776
7877
Returns
7978
-------
80-
xadj : :obj:`jaxlib.xla_extension.ArrayImpl`
79+
xadj : :obj:`jax.Array`
8180
Output array
8281
8382
"""

0 commit comments

Comments
 (0)