Skip to content

Commit 8e24bd2

Browse files
authored
Support max_iter=0 (#211)
* Support max_ier=0 gracefully * fix gap
1 parent ba245a4 commit 8e24bd2

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

celer/lasso_fast.pyx

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,12 @@ def celer(
5555
if p0 > n_features:
5656
p0 = n_features
5757

58-
cdef int i, j, k, t, idx, startptr, endptr, epoch
58+
cdef int t = 0
59+
cdef int i, j, k, idx, startptr, endptr, epoch
5960
cdef int ws_size = 0
6061
cdef int nnz = 0
61-
cdef floating gap, p_obj, d_obj, highest_d_obj, radius, tol_in
62+
cdef floating gap = -1
63+
cdef floating p_obj, d_obj, highest_d_obj, radius, tol_in
6264
cdef floating gap_in, p_obj_in, d_obj_in, d_obj_accel, highest_d_obj_in
6365
cdef floating tmp, R_sum, tmp_exp, scal
6466
cdef int n_screened = 0
@@ -96,7 +98,9 @@ def celer(
9698

9799
cdef floating norm_y2 = fnrm2(&n_samples, &y[0], &inc) ** 2
98100

99-
cdef floating[:] gaps = np.zeros(max_iter, dtype=dtype)
101+
# max_iter + 1 is to deal with max_iter=0
102+
cdef floating[:] gaps = np.zeros(max_iter + 1, dtype=dtype)
103+
gaps[0] = -1
100104

101105
cdef floating[:] theta_in = np.zeros(n_samples, dtype=dtype)
102106
cdef floating[:] thetacc = np.zeros(n_samples, dtype=dtype)

0 commit comments

Comments
 (0)