
Commit 5309fd9

FEAT - Add log-sum penalty (scikit-learn-contrib#127)
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent 57c6091 commit 5309fd9

6 files changed: +160 −8 lines changed

doc/api.rst
Lines changed: 2 additions & 1 deletion

@@ -40,6 +40,7 @@ Penalties
 L1_plus_L2
 L2
 L2_3
+LogSumPenalty
 MCPenalty
 PositiveConstraint
 WeightedL1

@@ -96,7 +97,7 @@ Experimental
 :toctree: generated/

 IterativeReweightedL1
-PDCD_WS
+PDCD_WS
 Pinball
 SqrtQuadratic
 SqrtLasso

doc/changes/0.4.rst
Lines changed: 2 additions & 1 deletion

@@ -3,4 +3,5 @@
 Version 0.4 (in progress)
 ---------------------------
 - Add support for weights and positive coefficients to :ref:`MCPRegression Estimator <skglm.MCPRegression>` (PR: :gh:`184`)
-- Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`)
+- Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`)
+- Add :ref:`LogSumPenalty <skglm.penalties.LogSumPenalty>` (PR: :gh:`#127`)

skglm/penalties/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 from .base import BasePenalty
 from .separable import (
     L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD,
-    WeightedL1, IndicatorBox, PositiveConstraint
+    WeightedL1, IndicatorBox, PositiveConstraint, LogSumPenalty
 )
 from .block_separable import (
     L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2

@@ -14,5 +14,5 @@
     BasePenalty,
     L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD, WeightedL1,
     IndicatorBox, PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD,
-    WeightedGroupL2, SLOPE
+    WeightedGroupL2, SLOPE, LogSumPenalty
 ]

skglm/penalties/separable.py
Lines changed: 62 additions & 2 deletions

@@ -4,8 +4,8 @@

 from skglm.penalties.base import BasePenalty
 from skglm.utils.prox_funcs import (
-    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP,
-    value_MCP, value_weighted_MCP)
+    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP,
+    value_weighted_MCP, prox_log_sum)


 class L1(BasePenalty):

@@ -607,6 +607,66 @@ def generalized_support(self, w):
         return w != 0


+class LogSumPenalty(BasePenalty):
+    """Log sum penalty.
+
+    The penalty value reads
+
+    .. math::
+
+        "value"(w) = sum_(j=1)^(n_"features") log(1 + abs(w_j) / epsilon)
+    """
+
+    def __init__(self, alpha, eps):
+        self.alpha = alpha
+        self.eps = eps
+
+    def get_spec(self):
+        spec = (
+            ('alpha', float64),
+            ('eps', float64),
+        )
+        return spec
+
+    def params_to_dict(self):
+        return dict(alpha=self.alpha, eps=self.eps)
+
+    def value(self, w):
+        """Compute the value of the log-sum penalty at ``w``."""
+        return self.alpha * np.sum(np.log(1 + np.abs(w) / self.eps))
+
+    def derivative(self, w):
+        """Compute the element-wise derivative."""
+        return np.sign(w) / (np.abs(w) + self.eps)
+
+    def prox_1d(self, value, stepsize, j):
+        """Compute the proximal operator of the log-sum penalty."""
+        return prox_log_sum(value, self.alpha * stepsize, self.eps)
+
+    def subdiff_distance(self, w, grad, ws):
+        """Compute distance of negative gradient to the subdifferential at w."""
+        subdiff_dist = np.zeros_like(grad)
+        alpha = self.alpha
+        eps = self.eps
+
+        for idx, j in enumerate(ws):
+            if w[j] == 0:
+                # distance of -grad_j to [-alpha/eps, alpha/eps]
+                subdiff_dist[idx] = max(0, np.abs(grad[idx]) - alpha / eps)
+            else:
+                # distance of -grad_j to {alpha * sign(w[j]) / (eps + |w[j]|)}
+                subdiff_dist[idx] = np.abs(
+                    grad[idx] + np.sign(w[j]) * alpha / (eps + np.abs(w[j])))
+        return subdiff_dist
+
+    def is_penalized(self, n_features):
+        """Return a binary mask with the penalized features."""
+        return np.ones(n_features, bool_)
+
+    def generalized_support(self, w):
+        return w != 0
+
+
 class PositiveConstraint(BasePenalty):
     """Positivity constraint penalty."""
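Below is a minimal usage sketch for the new penalty; it is not part of the commit. It only combines objects that the test module further down already imports (GeneralizedLinearEstimator, Quadratic, AndersonCD, make_correlated_data), and alpha=0.1 / eps=1e-2 are arbitrary illustration values, not recommended settings.

    from skglm import GeneralizedLinearEstimator
    from skglm.datafits import Quadratic
    from skglm.penalties import LogSumPenalty
    from skglm.solvers import AndersonCD
    from skglm.utils.data import make_correlated_data

    # toy data, same shape as in the tests (20 samples, 10 features)
    X, y, _ = make_correlated_data(n_samples=20, n_features=10, random_state=0)

    # quadratic datafit + log-sum penalty, solved with Anderson-accelerated coordinate descent
    estimator = GeneralizedLinearEstimator(
        datafit=Quadratic(),
        penalty=LogSumPenalty(alpha=0.1, eps=1e-2),  # illustrative hyperparameters
        solver=AndersonCD(),
    )
    estimator.fit(X, y)
    print(estimator.coef_)  # the non-convex log-sum penalty typically yields many exact zeros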

skglm/tests/test_penalties.py
Lines changed: 24 additions & 2 deletions

@@ -9,11 +9,13 @@
 from skglm.datafits import Quadratic, QuadraticMultiTask
 from skglm.penalties import (
     L1, L1_plus_L2, WeightedL1, MCPenalty, SCAD, IndicatorBox, L0_5, L2_3, SLOPE,
-    PositiveConstraint, L2_1, L2_05, BlockMCPenalty, BlockSCAD)
+    LogSumPenalty, PositiveConstraint, L2_1, L2_05, BlockMCPenalty, BlockSCAD)
 from skglm import GeneralizedLinearEstimator, Lasso
 from skglm.solvers import AndersonCD, MultiTaskBCD, FISTA
 from skglm.utils.data import make_correlated_data

+from skglm.utils.prox_funcs import prox_log_sum, _log_sum_prox_val
+

 n_samples = 20
 n_features = 10

@@ -37,7 +39,9 @@
     SCAD(alpha=alpha, gamma=4),
     IndicatorBox(alpha=alpha),
     L0_5(alpha),
-    L2_3(alpha)]
+    L2_3(alpha),
+    LogSumPenalty(alpha=alpha, eps=1e-2)
+]

 block_penalties = [
     L2_1(alpha=alpha), L2_05(alpha=alpha),

@@ -118,5 +122,23 @@ def test_nnls(fit_intercept):
     np.testing.assert_allclose(clf.intercept_, reg_nnls.intercept_)


+def test_logsum_prox():
+    alpha = 1.
+
+    grid_z = np.linspace(-2, 2, num=10)
+    grid_test = np.linspace(-5, 5, num=100)
+    grid_eps = np.linspace(0, 5, num=10 + 1)[1:]
+
+    for z, eps in zip(grid_z, grid_eps):
+        prox = prox_log_sum(z, alpha, eps)
+        obj_at_prox = _log_sum_prox_val(prox, z, alpha, eps)
+
+        is_lowest = all(
+            obj_at_prox <= _log_sum_prox_val(x, z, alpha, eps) for x in grid_test
+        )
+
+        np.testing.assert_equal(is_lowest, True)
+
+
 if __name__ == "__main__":
     pass

skglm/utils/prox_funcs.py
Lines changed: 68 additions & 0 deletions

@@ -204,3 +204,71 @@ def prox_SLOPE(z, alphas):
         x[i] = d

     return x
+
+
+@njit
+def prox_log_sum(x, alpha, eps):
+    """Proximal operator of log-sum penalty.
+
+    Parameters
+    ----------
+    x : float
+        Coefficient.
+
+    alpha : float
+        Regularization hyperparameter.
+
+    eps : float
+        Curvature hyperparameter.
+
+    References
+    ----------
+    .. [1] Ashley Prater-Bennette, Lixin Shen, Erin E. Tripp
+        The Proximity Operator of the Log-Sum Penalty (2021)
+    """
+    if np.sqrt(alpha) <= eps:
+        if abs(x) <= alpha / eps:
+            return 0.
+        else:
+            return np.sign(x) * _r2(abs(x), alpha, eps)
+    else:
+        a = 2 * np.sqrt(alpha) - eps
+        b = alpha / eps
+        # _r is continuous and _r(a) * _r(b) < 0, so the root can be found by bisection
+        x_star = _find_root_by_bisection(a, b, alpha, eps)
+        if abs(x) <= x_star:
+            return 0.
+        else:
+            return np.sign(x) * _r2(abs(x), alpha, eps)
+
+
+@njit
+def _r2(x, alpha, eps):
+    # compute r2 as in (eq. 7), ref [1] in `prox_log_sum`
+    return (x - eps) / 2. + np.sqrt(((x + eps) ** 2) / 4 - alpha)
+
+
+@njit
+def _log_sum_prox_val(x, z, alpha, eps):
+    # prox objective of the log-sum penalty `log(1 + abs(x) / eps)`
+    return ((x - z) ** 2) / (2 * alpha) + np.log1p(np.abs(x) / eps)
+
+
+@njit
+def _r(x, alpha, eps):
+    # compute r as defined in (eq. 9), ref [1] in `prox_log_sum`
+    r_z = _log_sum_prox_val(_r2(x, alpha, eps), x, alpha, eps)
+    r_0 = _log_sum_prox_val(0, x, alpha, eps)
+    return r_z - r_0
+
+
+@njit
+def _find_root_by_bisection(a, b, alpha, eps, tol=1e-8):
+    # find the root of `_r` in the interval [a, b] by bisection
+    while b - a > tol:
+        c = (a + b) / 2.
+        if _r(a, alpha, eps) * _r(c, alpha, eps) < 0:
+            b = c
+        else:
+            a = c
+    return c
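As a side note, not part of the diff: reading _log_sum_prox_val and _r2, the scalar problem these helpers solve, for an input z, is

    \mathrm{prox}(z) \in \operatorname*{arg\,min}_{x} \; \frac{(x - z)^2}{2 \alpha} + \log\Big(1 + \frac{|x|}{\varepsilon}\Big),
    \qquad
    r_2(|z|) = \frac{|z| - \varepsilon}{2} + \sqrt{\frac{(|z| + \varepsilon)^2}{4} - \alpha}.

prox_log_sum returns either 0 or sign(z) * r_2(|z|): below a threshold on |z| the zero candidate wins, above it the r_2 candidate does. The threshold is alpha / eps in closed form when sqrt(alpha) <= eps, and otherwise the bisection root of _r, i.e. the point where the two candidates attain equal objective values.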
