@@ -696,21 +696,42 @@ class OMP(Solver):
696696 \DeclareMathOperator*{\argmin}{arg\,min}
697697 \DeclareMathOperator*{\argmax}{arg\,max}
698698 \Lambda_k = \Lambda_{k-1} \cup \left\{\argmax_j
699- \left|\mathbf{Op}_j^H \,\mathbf{r}_k\right| \right\} \\
699+ \left|(\mathbf{Op}^j)^H \,\mathbf{r}_k\right| \right\} \\
700700 \mathbf{x}_k = \argmin_{\mathbf{x}}
701701 \left\|\mathbf{Op}_{\Lambda_k}\,\mathbf{x} - \mathbf{y}\right\|_2^2
702702
703+ where :math:`\mathbf{Op}^j` is the :math:`j`-th column of the operator,
704+ :math:`\mathbf{r}_k` is the residual at iteration :math:`k`, and
705+ :math:`\mathbf{Op}_{\Lambda_k}` is the operator restricted to the columns
706+ in the set :math:`\Lambda_k`.
707+
703708 Note that by choosing ``niter_inner=0`` the basic Matching Pursuit (MP)
704709 algorithm is implemented instead. In other words, instead of solving an
705710 optimization at each iteration to find the best :math:`\mathbf{x}` for the
706- currently selected basis functions, the vector :math:`\mathbf{x}` is just
707- updated at the new basis function by taking directly the value from
708- the inner product :math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
709-
710- In this case it is highly recommended to provide a normalized basis
711- function. If different basis have different norms, the solver is likely
712- to diverge. Similar observations apply to OMP, even though mild unbalancing
713- between the basis is generally properly handled.
711+ currently selected basis functions, either the vector :math:`\mathbf{x}`
712+ is just updated at the new basis function by adding the value from
713+ the inner product :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` to the current value
714+ (``optimal_coeff=False``) or the optimal coefficient that minimizes the norm
715+ of the residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` is estimated
716+ (``optimal_coeff=True``) and added to the current value.
717+
718+ In the case the MP solver is used, it is highly recommended to provide a
719+ normalized basis function. If different basis have different norms, the
720+ solver is likely to diverge. Similar observations apply to OMP, even
721+ though mild unbalancing between the basis is generally properly handled.
722+ Two possible ways to handle the scenario of non-normalized basis functions
723+ are:
724+
725+ - Find the normalization factor of the basis functions before
726+ running the solver (this is done by choosing ``normalizecols=True``);
727+ - Find the optimal coefficient that minimizes the norm of the residual
728+ :math:`\mathbf{r} - c\,\mathbf{Op}^j` at every iteration (this is
729+ done by choosing ``optimal_coeff=True``).
730+
731+ Finally, when the operator is a chain of operators, with the right-most
732+ representing the basis function, if the operator of the basis function is
733+ provided in the ``Opbasis`` parameter, the solver will use this operator
734+ to find the normalization factor for each column of the operator.
714735
715736 """
716737
@@ -742,6 +763,8 @@ def setup(
742763 niter_inner : int = 40 ,
743764 sigma : float = 1e-4 ,
744765 normalizecols : bool = False ,
766+ Opbasis : Optional ["LinearOperator" ] = None ,
767+ optimal_coeff : bool = False ,
745768 show : bool = False ,
746769 ) -> None :
747770 r"""Setup solver
@@ -763,6 +786,14 @@ def setup(
763786 :math:`n_{cols}` times to unit vectors (i.e., containing 1 at
764787 position j and zero otherwise); use only when the columns of the
765788 operator are expected to have highly varying norms.
789+ Opbasis : :obj:`pylops.LinearOperator`, optional
790+ Operator representing the basis functions. If ``None``, the
791+ operator used for inversion, ``Op``, is used.
792+ optimal_coeff : :obj:`bool`, optional
793+ Estimate the optimal coefficient that minimizes the norm of the
794+ residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` (``True``) or
795+ directly use the value from the inner product
796+ :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` (``False``).
766797 show : :obj:`bool`, optional
767798 Display setup log
768799
@@ -772,17 +803,19 @@ def setup(
772803 self .niter_inner = niter_inner
773804 self .sigma = sigma
774805 self .normalizecols = normalizecols
806+ self .Opbasis = Opbasis if Opbasis is not None else self .Op
807+ self .optimal_coeff = optimal_coeff
775808 self .ncp = get_array_module (y )
776809
777810 # find normalization factor for each column
778811 if self .normalizecols :
779- ncols = self .Op .shape [1 ]
812+ ncols = self .Opbasis .shape [1 ]
780813 self .norms = self .ncp .zeros (ncols )
781814 for icol in range (ncols ):
782- unit = self .ncp .zeros (ncols , dtype = self .Op .dtype )
815+ unit = self .ncp .zeros (ncols , dtype = self .Opbasis .dtype )
783816 unit [icol ] = 1
784- self .norms [icol ] = np .linalg .norm (self .Op .matvec (unit ))
785-
817+ self .norms [icol ] = np .linalg .norm (self .Opbasis .matvec (unit ))
818+ print ( f" { self . norms = } " )
786819 # create variables to track the residual norm and iterations
787820 self .res = self .y .copy ()
788821 self .cost = [
@@ -820,9 +853,9 @@ def step(
820853 """
821854 # compute inner products
822855 cres = self .Op .rmatvec (self .res )
823- cres_abs = np .abs (cres )
824856 if self .normalizecols :
825- cres_abs = cres_abs / self .norms
857+ cres = cres / self .norms
858+ cres_abs = np .abs (cres )
826859 # choose column with max cres
827860 cres_max = np .max (cres_abs )
828861 imax = np .argwhere (cres_abs == cres_max ).ravel ()
@@ -847,11 +880,22 @@ def step(
847880 int (imax ),
848881 ]
849882 )
850- self .res -= Opcol .matvec (cres [imax ] * self .ncp .ones (1 ))
851- if addnew :
852- x .append (cres [imax ])
883+ if not self .optimal_coeff :
884+ # update with coefficient that maximizes the inner product
885+ self .res -= Opcol .matvec (cres [imax ] * self .ncp .ones (1 ))
886+ if addnew :
887+ x .append (cres [imax ])
888+ else :
889+ x [imax_in_cols ] += cres [imax ]
853890 else :
854- x [imax_in_cols ] += cres [imax ]
891+ # find optimal coefficient that minimizes the residual (r - cres * col)
892+ col = Opcol .matvec (self .ncp .ones (1 , dtype = Opcol .dtype ))
893+ cresopt = (Opcol .rmatvec (self .res ) / Opcol .rmatvec (col ))[0 ]
894+ self .res -= Opcol .matvec (cresopt * self .ncp .ones (1 ))
895+ if addnew :
896+ x .append (cresopt )
897+ else :
898+ x [imax_in_cols ] += cresopt
855899 else :
856900 # OMP update
857901 Opcol = self .Op .apply_columns (cols )
@@ -958,6 +1002,8 @@ def solve(
9581002 niter_inner : int = 40 ,
9591003 sigma : float = 1e-4 ,
9601004 normalizecols : bool = False ,
1005+ Opbasis : Optional ["LinearOperator" ] = None ,
1006+ optimal_coeff : bool = False ,
9611007 show : bool = False ,
9621008 itershow : Tuple [int , int , int ] = (10 , 10 , 10 ),
9631009 ) -> Tuple [NDArray , int , NDArray ]:
@@ -980,6 +1026,14 @@ def solve(
9801026 :math:`n_{cols}` times to unit vectors (i.e., containing 1 at
9811027 position j and zero otherwise); use only when the columns of the
9821028 operator are expected to have highly varying norms.
1029+ Opbasis : :obj:`pylops.LinearOperator`, optional
1030+ Operator representing the basis functions. If ``None``, the
1031+ operator used for inversion, ``Op``, is used.
1032+ optimal_coeff : :obj:`bool`, optional
1033+ Estimate the optimal coefficient that minimizes the norm of the
1034+ residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` (``True``) or
1035+ directly use the value from the inner product
1036+ :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` (``False``).
9831037 show : :obj:`bool`, optional
9841038 Display logs
9851039 itershow : :obj:`tuple`, optional
@@ -1003,6 +1057,8 @@ def solve(
10031057 niter_inner = niter_inner ,
10041058 sigma = sigma ,
10051059 normalizecols = normalizecols ,
1060+ Opbasis = Opbasis ,
1061+ optimal_coeff = optimal_coeff ,
10061062 show = show ,
10071063 )
10081064 x : List [NDArray ] = []
0 commit comments