Skip to content

Commit 14adc13

Browse files
committed
feat: enable jax in todense
1 parent c9d0eae commit 14adc13

File tree

2 files changed

+33
-31
lines changed

2 files changed

+33
-31
lines changed

pylops/jaxoperator.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@
2424

2525

2626
class JaxOperator(LinearOperator):
27+
"""Enable JAX backend for PyLops operator.
28+
29+
This class can be used to wrap a pylops operator to enable the JAX
30+
backend. Doing so, users can run all of the methods of a pylops
31+
operator with JAX arrays. Moreover, the forward and adjoint
32+
are internally just-in-time compiled, and other JAX functionalities
33+
such as automatic differentiation and automatic vectorization
34+
are enabled.
35+
36+
Parameters
37+
----------
38+
Op : :obj:`pylops.LinearOperator`
39+
PyLops operator
40+
41+
"""
42+
2743
def __init__(self, Op: LinearOperator) -> None:
2844
if not deps.jax_enabled:
2945
raise NotImplementedError(jax_message)
@@ -43,7 +59,12 @@ def __call__(self, x, *args, **kwargs):
4359
return self._matvec(x)
4460

4561
def _rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
46-
"""Vector-Jacobian products
62+
_, f_vjp = jax.vjp(self._matvec, x)
63+
xadj = jax.jit(f_vjp)(y)[0]
64+
return xadj
65+
66+
def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
67+
"""Vector-Jacobian product
4768
4869
JIT-compiled Vector-Jacobian product
4970
@@ -59,33 +80,12 @@ def _rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
5980
xadj : :obj:`jaxlib.xla_extension.ArrayImpl`
6081
Output array
6182
62-
"""
63-
_, f_vjp = jax.vjp(self._matvec, x)
64-
xadj = jax.jit(f_vjp)(y)[0]
65-
return xadj
66-
67-
def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
68-
"""Adjoint matrix-vector multiplication with AD
69-
70-
Parameters
71-
----------
72-
x : :obj:`jaxlib.xla_extension.ArrayImpl`
73-
Input array
74-
y : :obj:`jaxlib.xla_extension.ArrayImpl`
75-
Output array (where to store the
76-
Vector-Jacobian product)
77-
78-
Returns
79-
-------
80-
x : :obj:`numpy.ndarray`
81-
Output array of shape (N,) or (N,1)
82-
8383
"""
8484
M, N = self.shape
8585

8686
if x.shape != (M,) and x.shape != (M, 1):
8787
raise ValueError(
88-
f"Dimension mismatch. Got {x.shape}, but expected {(M, 1)} or {(M,)}."
88+
f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)."
8989
)
9090

9191
y = self._rmatvecad(x, y)

pylops/linearoperator.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -442,10 +442,11 @@ def _matmat(self, X: NDArray) -> NDArray:
442442
Modified version of scipy _matmat to avoid having trailing dimension
443443
in col when provided to matvec
444444
"""
445+
ncp = get_array_module(X)
445446
if sp.sparse.issparse(X):
446-
y = np.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T
447+
y = ncp.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T
447448
else:
448-
y = np.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T
449+
y = ncp.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T
449450
return y
450451

451452
def _rmatmat(self, X: NDArray) -> NDArray:
@@ -454,10 +455,11 @@ def _rmatmat(self, X: NDArray) -> NDArray:
454455
Modified version of scipy _rmatmat to avoid having trailing dimension
455456
in col when provided to rmatvec
456457
"""
458+
ncp = get_array_module(X)
457459
if sp.sparse.issparse(X):
458-
y = np.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T
460+
y = ncp.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T
459461
else:
460-
y = np.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T
462+
y = ncp.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T
461463
return y
462464

463465
def _adjoint(self) -> LinearOperator:
@@ -509,7 +511,7 @@ def matvec(self, x: NDArray) -> NDArray:
509511

510512
if x.shape != (N,) and x.shape != (N, 1):
511513
raise ValueError(
512-
f"Dimension mismatch. Got {x.shape}, but expected {(M, 1)} or {(M,)}."
514+
f"Dimension mismatch. Got {x.shape}, but expected ({N},) or ({N}, 1)."
513515
)
514516

515517
y = self._matvec(x)
@@ -545,7 +547,7 @@ def rmatvec(self, x: NDArray) -> NDArray:
545547

546548
if x.shape != (M,) and x.shape != (M, 1):
547549
raise ValueError(
548-
f"Dimension mismatch. Got {x.shape}, but expected {(M, 1)} or {(M,)}."
550+
f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)."
549551
)
550552

551553
y = self._rmatvec(x)
@@ -795,7 +797,7 @@ def todense(
795797
Parameters
796798
----------
797799
backend : :obj:`str`, optional
798-
Backend used to densify matrix (``numpy`` or ``cupy``). Note that
800+
Backend used to densify matrix (``numpy`` or ``cupy`` or ``jax``). Note that
799801
this must be consistent with how the operator has been created.
800802
801803
Returns
@@ -820,7 +822,7 @@ def todense(
820822
if Op.shape[1] == shapemin:
821823
matrix = Op.matmat(identity)
822824
else:
823-
matrix = np.conj(Op.rmatmat(identity)).T
825+
matrix = ncp.conj(Op.rmatmat(identity)).T
824826
return matrix
825827

826828
def tosparse(self) -> NDArray:

0 commit comments

Comments
 (0)