@@ -55,10 +55,12 @@ def celer(
5555 if p0 > n_features:
5656 p0 = n_features
5757
58- cdef int i, j, k, t, idx, startptr, endptr, epoch
58+ cdef int t = 0
59+ cdef int i, j, k, idx, startptr, endptr, epoch
5960 cdef int ws_size = 0
6061 cdef int nnz = 0
61- cdef floating gap, p_obj, d_obj, highest_d_obj, radius, tol_in
62+ cdef floating gap = - 1
63+ cdef floating p_obj, d_obj, highest_d_obj, radius, tol_in
6264 cdef floating gap_in, p_obj_in, d_obj_in, d_obj_accel, highest_d_obj_in
6365 cdef floating tmp, R_sum, tmp_exp, scal
6466 cdef int n_screened = 0
@@ -96,7 +98,9 @@ def celer(
9698
9799 cdef floating norm_y2 = fnrm2(& n_samples, & y[0 ], & inc) ** 2
98100
99- cdef floating[:] gaps = np.zeros(max_iter, dtype = dtype)
101+ # max_iter + 1 is to deal with max_iter=0
102+ cdef floating[:] gaps = np.zeros(max_iter + 1 , dtype = dtype)
103+ gaps[0 ] = - 1
100104
101105 cdef floating[:] theta_in = np.zeros(n_samples, dtype = dtype)
102106 cdef floating[:] thetacc = np.zeros(n_samples, dtype = dtype)
0 commit comments