@@ -442,10 +442,11 @@ def _matmat(self, X: NDArray) -> NDArray:
442442 Modified version of scipy _matmat to avoid having trailing dimension
443443 in col when provided to matvec
444444 """
445+ ncp = get_array_module (X )
445446 if sp .sparse .issparse (X ):
446- y = np .vstack ([self .matvec (col .toarray ().reshape (- 1 )) for col in X .T ]).T
447+ y = ncp .vstack ([self .matvec (col .toarray ().reshape (- 1 )) for col in X .T ]).T
447448 else :
448- y = np .vstack ([self .matvec (col .reshape (- 1 )) for col in X .T ]).T
449+ y = ncp .vstack ([self .matvec (col .reshape (- 1 )) for col in X .T ]).T
449450 return y
450451
451452 def _rmatmat (self , X : NDArray ) -> NDArray :
@@ -454,10 +455,11 @@ def _rmatmat(self, X: NDArray) -> NDArray:
454455 Modified version of scipy _rmatmat to avoid having trailing dimension
455456 in col when provided to rmatvec
456457 """
458+ ncp = get_array_module (X )
457459 if sp .sparse .issparse (X ):
458- y = np .vstack ([self .rmatvec (col .toarray ().reshape (- 1 )) for col in X .T ]).T
460+ y = ncp .vstack ([self .rmatvec (col .toarray ().reshape (- 1 )) for col in X .T ]).T
459461 else :
460- y = np .vstack ([self .rmatvec (col .reshape (- 1 )) for col in X .T ]).T
462+ y = ncp .vstack ([self .rmatvec (col .reshape (- 1 )) for col in X .T ]).T
461463 return y
462464
463465 def _adjoint (self ) -> LinearOperator :
@@ -509,7 +511,7 @@ def matvec(self, x: NDArray) -> NDArray:
509511
510512 if x .shape != (N ,) and x .shape != (N , 1 ):
511513 raise ValueError (
512- f"Dimension mismatch. Got { x .shape } , but expected { ( M , 1 ) } or { ( M ,) } ."
514+ f"Dimension mismatch. Got { x .shape } , but expected ( { N } ,) or ( { N } , 1) ."
513515 )
514516
515517 y = self ._matvec (x )
@@ -545,7 +547,7 @@ def rmatvec(self, x: NDArray) -> NDArray:
545547
546548 if x .shape != (M ,) and x .shape != (M , 1 ):
547549 raise ValueError (
548- f"Dimension mismatch. Got { x .shape } , but expected { ( M , 1 ) } or { ( M ,) } ."
550+ f"Dimension mismatch. Got { x .shape } , but expected ( { M } ,) or ( { M } , 1) ."
549551 )
550552
551553 y = self ._rmatvec (x )
@@ -795,7 +797,7 @@ def todense(
795797 Parameters
796798 ----------
797799 backend : :obj:`str`, optional
798- Backend used to densify matrix (``numpy`` or ``cupy``). Note that
800+ Backend used to densify matrix (``numpy`` or ``cupy`` or ``jax`` ). Note that
799801 this must be consistent with how the operator has been created.
800802
801803 Returns
@@ -820,7 +822,7 @@ def todense(
820822 if Op .shape [1 ] == shapemin :
821823 matrix = Op .matmat (identity )
822824 else :
823- matrix = np .conj (Op .rmatmat (identity )).T
825+ matrix = ncp .conj (Op .rmatmat (identity )).T
824826 return matrix
825827
826828 def tosparse (self ) -> NDArray :
0 commit comments