Skip to content

Commit 9bec836

Browse files
authored
Merge pull request #186 from mrava87/feat-nonlin
feat: added fungrad to Nonlinear
2 parents 2f7b7ad + f17bb13 commit 9bec836

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

pyproximal/proximal/Nonlinear.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class Nonlinear(ProxOperator):
1414
- ``fun``: a method evaluating the generic function :math:`f`
1515
- ``grad``: a method evaluating the gradient of the generic function
1616
:math:`f`
17+
- ``fungrad``: a method evaluating both the generic function :math:`f`
18+
and its gradient
1719
- ``optimize``: a method that solves the optimization problem associated
1820
with the proximal operator of :math:`f`. Note that the
1921
``gradprox`` method must be used (instead of ``grad``) as this will
@@ -58,6 +60,12 @@ def _funprox(self, x, tau):
5860
def _gradprox(self, x, tau):
5961
return self.grad(x) + 1. / tau * (x - self.y)
6062

63+
def _fungradprox(self, x, tau):
64+
f, g = self.fungrad(x)
65+
f = f + 1. / (2 * tau) * ((x - self.y) ** 2).sum()
66+
g = g + 1. / tau * (x - self.y)
67+
return f, g
68+
6169
def fun(self, x):
6270
raise NotImplementedError('The method fun has not been implemented.'
6371
'Refer to the documentation for details on '
@@ -66,6 +74,10 @@ def grad(self, x):
6674
raise NotImplementedError('The method grad has not been implemented.'
6775
'Refer to the documentation for details on '
6876
'how to subclass this operator.')
77+
def fungrad(self, x):
78+
raise NotImplementedError('The method fungrad has not been implemented.'
79+
'Refer to the documentation for details on '
80+
'how to subclass this operator.')
6981
def optimize(self):
7082
raise NotImplementedError('The method optimize has not been implemented.'
7183
'Refer to the documentation for details on '

pytests/test_grads.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def test_l2(par):
5252
raiseerror=True, atol=1e-3,
5353
verb=False)
5454

55+
5556
@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
5657
def test_lowrank(par):
5758
"""LowRankFactorizedMatrix gradient

0 commit comments

Comments
 (0)