Skip to content

Commit 69f27f2

Browse files
authored
support max_iter = 0 for Logreg with prox newton solver (#213)
1 parent 6aab853 commit 69f27f2

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

celer/PN_logreg.pyx

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def newton_celer(
4141
# scale tol for when problem has large or small p_obj
4242
tol *= n_samples * np.log(2)
4343

44-
cdef int i, j, t, k
45-
cdef floating p_obj, d_obj, gap, norm_Xtheta, norm_Xtheta_acc
46-
cdef floating tmp
44+
cdef int t = 0
45+
cdef int i, j, k
46+
cdef floating p_obj, d_obj, norm_Xtheta, norm_Xtheta_acc, tmp
47+
cdef floating gap = -1 # initialized for the warning if max_iter=0
4748
cdef int info_dposv
4849
cdef int ws_size
4950
cdef floating eps_inner = 0.1
@@ -54,7 +55,8 @@ def newton_celer(
5455
cdef int[:] all_features = np.arange(n_features, dtype=np.int32)
5556
cdef floating[:] prios = np.empty(n_features, dtype=dtype)
5657
cdef int[:] WS
57-
cdef floating[:] gaps = np.zeros(max_iter, dtype=dtype)
58+
cdef floating[:] gaps = np.zeros(max(1, max_iter), dtype=dtype)
59+
gaps[0] = gap # support max_iter = 0
5860
cdef floating[:] X_mean = np.zeros(n_features, dtype=dtype)
5961
cdef bint center = False
6062
# TODO support centering

celer/lasso_fast.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def celer(
5959
cdef int i, j, k, idx, startptr, endptr, epoch
6060
cdef int ws_size = 0
6161
cdef int nnz = 0
62-
cdef floating gap = -1
62+
cdef floating gap = -1 # initialized for the warning if max_iter=0
6363
cdef floating p_obj, d_obj, highest_d_obj, radius, tol_in
6464
cdef floating gap_in, p_obj_in, d_obj_in, d_obj_accel, highest_d_obj_in
6565
cdef floating tmp, R_sum, tmp_exp, scal

celer/tests/test_lasso.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Lasso as sklearn_Lasso, lasso_path)
1717

1818
from celer import celer_path
19-
from celer.dropin_sklearn import Lasso, LassoCV
19+
from celer.dropin_sklearn import Lasso, LassoCV, LogisticRegression
2020
from celer.utils.testing import build_dataset
2121

2222

@@ -217,5 +217,20 @@ def test_weights():
217217
np.testing.assert_raises(ValueError, clf1.fit, X=X, y=y)
218218

219219

220+
def test_zero_iter():
221+
X, y = build_dataset(n_samples=30, n_features=50)
222+
223+
# convergence warning is raised bc we return -1 as gap
224+
with warnings.catch_warnings(record=True):
225+
assert_allclose(Lasso(max_iter=0).fit(X, y).coef_, 0)
226+
y = 2 * (y > 0) - 1
227+
assert_allclose(
228+
LogisticRegression(max_iter=0, solver="celer-pn").fit(X, y).coef_,
229+
0)
230+
assert_allclose(
231+
LogisticRegression(max_iter=0, solver="celer").fit(X, y).coef_,
232+
0)
233+
234+
220235
if __name__ == "__main__":
221236
pass

0 commit comments

Comments
 (0)