
Commit 254beae

feat: added optimal_coeff to omp

1 parent 02d151e

File tree

4 files changed: +127, -25 lines


examples/plot_ista.py (19 additions, 6 deletions)

@@ -53,24 +53,26 @@
 # MP/OMP
 eps = 1e-2
 maxit = 500
-x_mp = pylops.optimization.sparsity.omp(
+x_mp, nitermp, costmp = pylops.optimization.sparsity.omp(
     Aop, y, niter_outer=maxit, niter_inner=0, sigma=1e-4
-)[0]
-x_omp = pylops.optimization.sparsity.omp(Aop, y, niter_outer=maxit, sigma=1e-4)[0]
+)
+x_omp, niteromp, costomp = pylops.optimization.sparsity.omp(
+    Aop, y, niter_outer=maxit, sigma=1e-4
+)

 # IRLS
 x_irls = pylops.optimization.sparsity.irls(
-    Aop, y, nouter=50, epsI=1e-5, kind="model", **dict(iter_lim=10)
+    Aop, y, nouter=maxit, epsI=1e-5, kind="model", **dict(iter_lim=10)
 )[0]

 # ISTA
-x_ista = pylops.optimization.sparsity.ista(
+x_ista, niteri, costi = pylops.optimization.sparsity.ista(
     Aop,
     y,
     niter=maxit,
     eps=eps,
     tol=1e-3,
-)[0]
+)

 fig, ax = plt.subplots(1, 1, figsize=(8, 3))
 m, s, b = ax.stem(x, linefmt="k", basefmt="k", markerfmt="ko", label="True")
@@ -87,6 +89,17 @@
 ax.legend()
 plt.tight_layout()

+fig, ax = plt.subplots(1, 1, figsize=(8, 3))
+ax.plot(costmp, "c", lw=2, label=r"$x_{MP} (niter=%d)$" % nitermp)
+ax.plot(costomp, "g", lw=2, label=r"$x_{OMP} (niter=%d)$" % niteromp)
+ax.plot(costi[: nitermp + 5], "r", lw=2, label=r"$x_{ISTA} (niter=%d)$" % niteri)
+ax.set_title("Cost function", size=15, fontweight="bold")
+ax.set_xlabel("Iteration")
+ax.legend()
+ax.grid(True, which="both")
+plt.tight_layout()
+
+
 ###############################################################################
 # We now consider a more interesting problem, *wavelet deconvolution*,
 # from a signal that we assume to be composed of a train of spikes convolved
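The reworked example above unpacks the full (x, niter, cost) tuple now returned by each solver so that cost histories can be compared across MP, OMP and ISTA. A minimal standalone sketch of the same idea, assuming the post-commit omp API; the toy MatrixMult problem and all variable names here are illustrative, not taken from the example:

import numpy as np
import pylops

np.random.seed(0)
A = np.random.randn(40, 100)
A /= np.linalg.norm(A, axis=0)  # normalize columns so plain MP stays stable
Aop = pylops.MatrixMult(A)

x = np.zeros(100)
x[[10, 50, 90]] = [1.0, -0.5, 2.0]  # sparse model with three spikes
y = Aop @ x

# plain MP (niter_inner=0): coefficient taken directly from the inner product
x_mp, nit_mp, cost_mp = pylops.optimization.sparsity.omp(
    Aop, y, niter_outer=200, niter_inner=0, sigma=1e-4
)

# MP with the new optimal (least-squares) coefficient for the selected column
x_mpo, nit_mpo, cost_mpo = pylops.optimization.sparsity.omp(
    Aop, y, niter_outer=200, niter_inner=0, optimal_coeff=True, sigma=1e-4
)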

pylops/optimization/cls_sparsity.py (75 additions, 19 deletions)

@@ -696,21 +696,42 @@ class OMP(Solver):
         \DeclareMathOperator*{\argmin}{arg\,min}
         \DeclareMathOperator*{\argmax}{arg\,max}
         \Lambda_k = \Lambda_{k-1} \cup \left\{\argmax_j
-        \left|\mathbf{Op}_j^H\,\mathbf{r}_k\right| \right\} \\
+        \left|(\mathbf{Op}^j)^H\,\mathbf{r}_k\right| \right\} \\
         \mathbf{x}_k = \argmin_{\mathbf{x}}
         \left\|\mathbf{Op}_{\Lambda_k}\,\mathbf{x} - \mathbf{y}\right\|_2^2

+    where :math:`\mathbf{Op}^j` is the :math:`j`-th column of the operator,
+    :math:`\mathbf{r}_k` is the residual at iteration :math:`k`, and
+    :math:`\mathbf{Op}_{\Lambda_k}` is the operator restricted to the columns
+    in the set :math:`\Lambda_k`.
+
     Note that by choosing ``niter_inner=0`` the basic Matching Pursuit (MP)
     algorithm is implemented instead. In other words, instead of solving an
     optimization at each iteration to find the best :math:`\mathbf{x}` for the
-    currently selected basis functions, the vector :math:`\mathbf{x}` is just
-    updated at the new basis function by taking directly the value from
-    the inner product :math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
-
-    In this case it is highly recommended to provide a normalized basis
-    function. If different basis have different norms, the solver is likely
-    to diverge. Similar observations apply to OMP, even though mild unbalancing
-    between the basis is generally properly handled.
+    currently selected basis functions, either the vector :math:`\mathbf{x}`
+    is updated at the new basis function by adding the value of the inner
+    product :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` to the current value
+    (``optimal_coeff=False``), or the optimal coefficient that minimizes the
+    norm of the residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` is estimated
+    (``optimal_coeff=True``) and added to the current value.
+
+    When the MP solver is used, it is highly recommended to provide
+    normalized basis functions. If different basis functions have different
+    norms, the solver is likely to diverge. Similar observations apply to
+    OMP, even though mild unbalancing between the basis functions is
+    generally handled properly. Two possible ways to handle the scenario of
+    non-normalized basis functions are:
+
+    - find the normalization factor of the basis functions before
+      running the solver (this is done by choosing ``normalizecols=True``);
+    - find the optimal coefficient that minimizes the norm of the residual
+      :math:`\mathbf{r} - c\,\mathbf{Op}^j` at every iteration (this is
+      done by choosing ``optimal_coeff=True``).
+
+    Finally, when the operator is a chain of operators, with the right-most
+    one representing the basis functions, the operator of the basis functions
+    can be provided via the ``Opbasis`` parameter; the solver will then use
+    it to find the normalization factor for each column of the operator.

     """

@@ -742,6 +763,8 @@ def setup(
         niter_inner: int = 40,
         sigma: float = 1e-4,
         normalizecols: bool = False,
+        Opbasis: Optional["LinearOperator"] = None,
+        optimal_coeff: bool = False,
         show: bool = False,
     ) -> None:
         r"""Setup solver

@@ -763,6 +786,14 @@
            :math:`n_{cols}` times to unit vectors (i.e., containing 1 at
            position j and zero otherwise); use only when the columns of the
            operator are expected to have highly varying norms.
+        Opbasis : :obj:`pylops.LinearOperator`, optional
+            Operator representing the basis functions. If ``None``, the entire
+            operator used for inversion ``Op`` is used.
+        optimal_coeff : :obj:`bool`, optional
+            Estimate the optimal coefficient that minimizes the norm of the
+            residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` (``True``), or use
+            directly the value of the inner product
+            :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` (``False``).
         show : :obj:`bool`, optional
            Display setup log

@@ -772,17 +803,19 @@
         self.niter_inner = niter_inner
         self.sigma = sigma
         self.normalizecols = normalizecols
+        self.Opbasis = Opbasis if Opbasis is not None else self.Op
+        self.optimal_coeff = optimal_coeff
         self.ncp = get_array_module(y)

         # find normalization factor for each column
         if self.normalizecols:
-            ncols = self.Op.shape[1]
+            ncols = self.Opbasis.shape[1]
             self.norms = self.ncp.zeros(ncols)
             for icol in range(ncols):
-                unit = self.ncp.zeros(ncols, dtype=self.Op.dtype)
+                unit = self.ncp.zeros(ncols, dtype=self.Opbasis.dtype)
                 unit[icol] = 1
-                self.norms[icol] = np.linalg.norm(self.Op.matvec(unit))
-
+                self.norms[icol] = np.linalg.norm(self.Opbasis.matvec(unit))
         # create variables to track the residual norm and iterations
         self.res = self.y.copy()
         self.cost = [

@@ -820,9 +853,9 @@
         """
         # compute inner products
         cres = self.Op.rmatvec(self.res)
-        cres_abs = np.abs(cres)
         if self.normalizecols:
-            cres_abs = cres_abs / self.norms
+            cres = cres / self.norms
+        cres_abs = np.abs(cres)
         # choose column with max cres
         cres_max = np.max(cres_abs)
         imax = np.argwhere(cres_abs == cres_max).ravel()

@@ -847,11 +880,22 @@
                     int(imax),
                 ]
             )
-            self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1))
-            if addnew:
-                x.append(cres[imax])
+            if not self.optimal_coeff:
+                # update with the coefficient that maximizes the inner product
+                self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1))
+                if addnew:
+                    x.append(cres[imax])
+                else:
+                    x[imax_in_cols] += cres[imax]
             else:
-                x[imax_in_cols] += cres[imax]
+                # find the optimal coefficient that minimizes the residual (r - cres * col)
+                col = Opcol.matvec(self.ncp.ones(1, dtype=Opcol.dtype))
+                cresopt = (Opcol.rmatvec(self.res) / Opcol.rmatvec(col))[0]
+                self.res -= Opcol.matvec(cresopt * self.ncp.ones(1))
+                if addnew:
+                    x.append(cresopt)
+                else:
+                    x[imax_in_cols] += cresopt
         else:
             # OMP update
             Opcol = self.Op.apply_columns(cols)

@@ -958,6 +1002,8 @@ def solve(
         niter_inner: int = 40,
         sigma: float = 1e-4,
         normalizecols: bool = False,
+        Opbasis: Optional["LinearOperator"] = None,
+        optimal_coeff: bool = False,
         show: bool = False,
         itershow: Tuple[int, int, int] = (10, 10, 10),
     ) -> Tuple[NDArray, int, NDArray]:

@@ -980,6 +1026,14 @@
            :math:`n_{cols}` times to unit vectors (i.e., containing 1 at
            position j and zero otherwise); use only when the columns of the
            operator are expected to have highly varying norms.
+        Opbasis : :obj:`pylops.LinearOperator`, optional
+            Operator representing the basis functions. If ``None``, the entire
+            operator used for inversion ``Op`` is used.
+        optimal_coeff : :obj:`bool`, optional
+            Estimate the optimal coefficient that minimizes the norm of the
+            residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` (``True``), or use
+            directly the value of the inner product
+            :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` (``False``).
         show : :obj:`bool`, optional
            Display logs
        itershow : :obj:`tuple`, optional

@@ -1003,6 +1057,8 @@
            niter_inner=niter_inner,
            sigma=sigma,
            normalizecols=normalizecols,
+            Opbasis=Opbasis,
+            optimal_coeff=optimal_coeff,
            show=show,
        )
        x: List[NDArray] = []
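The optimal-coefficient update has a simple closed form: for a selected column :math:`\mathbf{a}`, the scalar :math:`c` minimizing :math:`\|\mathbf{r} - c\,\mathbf{a}\|_2^2` is :math:`c = \mathbf{a}^H\mathbf{r} / \mathbf{a}^H\mathbf{a}`, which is what the step computes via ``Opcol.rmatvec(self.res) / Opcol.rmatvec(col)``: for a single-column operator, ``col = Opcol.matvec(ones(1))`` is :math:`\mathbf{a}` itself and ``Opcol.rmatvec(col)`` equals :math:`\mathbf{a}^H\mathbf{a}`. A minimal numpy sanity check of this formula (not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(50)  # selected (non-normalized) column
r = rng.standard_normal(50)  # current residual

# closed-form minimizer of ||r - c * a||_2^2 over the scalar c
c_opt = np.vdot(a, r) / np.vdot(a, a)

# brute-force check: no nearby scalar gives a smaller residual norm
cs = c_opt + np.linspace(-1, 1, 201)
res_norms = [np.linalg.norm(r - c * a) for c in cs]
assert np.argmin(res_norms) == 100  # the minimum sits exactly at c_opt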

pylops/optimization/sparsity.py (12 additions, 0 deletions)

@@ -129,6 +129,8 @@ def omp(
     niter_inner: int = 40,
     sigma: float = 1e-4,
     normalizecols: bool = False,
+    Opbasis: Optional["LinearOperator"] = None,
+    optimal_coeff: bool = False,
     show: bool = False,
     itershow: Tuple[int, int, int] = (10, 10, 10),
     callback: Optional[Callable] = None,

@@ -159,6 +161,14 @@
        :math:`n_{cols}` times to unit vectors (i.e., containing 1 at
        position j and zero otherwise); use only when the columns of the
        operator are expected to have highly varying norms.
+    Opbasis : :obj:`pylops.LinearOperator`, optional
+        Operator representing the basis functions. If not provided, the entire
+        operator used for inversion ``Op`` is used.
+    optimal_coeff : :obj:`bool`, optional
+        Estimate the optimal coefficient that minimizes the norm of the
+        residual :math:`\mathbf{r} - c\,\mathbf{Op}^j` (``True``), or use
+        directly the value of the inner product
+        :math:`(\mathbf{Op}^j)^H\,\mathbf{r}_k` (``False``).
     show : :obj:`bool`, optional
        Display iterations log
    itershow : :obj:`tuple`, optional

@@ -200,6 +210,8 @@
        niter_inner=niter_inner,
        sigma=sigma,
        normalizecols=normalizecols,
+        Opbasis=Opbasis,
+        optimal_coeff=optimal_coeff,
        show=show,
        itershow=itershow,
    )
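The Opbasis hook targets chained operators where only the right-most factor holds the basis functions, so that normalization factors are computed from the basis columns alone rather than from the full chain. A hedged usage sketch, assuming the post-commit signature; the operator shapes and names below are illustrative, not from the commit:

import numpy as np
import pylops

rng = np.random.default_rng(1)
D = pylops.MatrixMult(rng.standard_normal((60, 80)))  # basis functions (dictionary)
G = pylops.MatrixMult(rng.standard_normal((40, 60)))  # modelling operator
Op = G @ D  # chained operator, basis functions on the right

xtrue = np.zeros(80)
xtrue[3] = 1.0  # single active basis function
y = Op @ xtrue

# normalize against the columns of the basis only, not of the full chain
xinv, niter, cost = pylops.optimization.sparsity.omp(
    Op, y, niter_outer=100, sigma=1e-4, normalizecols=True, Opbasis=D
)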

pytests/test_sparsity.py (21 additions, 0 deletions)

@@ -171,6 +171,27 @@ def test_IRLS_model(par):
     assert_array_almost_equal(x, xinv, decimal=1)


+@pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)])
+def test_MP(par):
+    """Invert problem with MP"""
+    np.random.seed(42)
+    Aop = MatrixMult(np.random.randn(par["ny"], par["nx"]))
+
+    x = np.zeros(par["nx"])
+    x[par["nx"] // 2] = 1
+    x[3] = 1
+    x[par["nx"] - 4] = -1
+    y = Aop * x
+
+    sigma = 1e-4
+    maxit = 100
+
+    xinv, _, _ = omp(
+        Aop, y, maxit, niter_inner=0, optimal_coeff=True, sigma=sigma, show=False
+    )
+    assert_array_almost_equal(x, xinv, decimal=1)
+
+
 @pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)])
 def test_OMP(par):
     """Invert problem with OMP"""
