Skip to content

Commit 64aed96

Browse files
authored
FIX primal computations with infinite weights (#232)
1 parent 32147a2 commit 64aed96

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

celer/cython_utils.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ cdef floating primal_logreg(
129129
for i in range(n_samples):
130130
p_obj += log_1pexp(- y[i] * Xw[i])
131131
for j in range(n_features):
132-
p_obj += alpha * weights[j] * fabs(w[j])
132+
# avoid nan when weights[j] is INFINITY
133+
if w[j]:
134+
p_obj += alpha * weights[j] * fabs(w[j])
133135
return p_obj
134136

135137

@@ -147,7 +149,9 @@ cdef floating primal_lasso(
147149
cdef floating p_obj = 0.
148150
p_obj = fdot(&n_samples, &R[0], &inc, &R[0], &inc) / (2. * n_samples)
149151
for j in range(n_features):
150-
p_obj += alpha * weights[j] * fabs(w[j])
152+
# avoid nan when weights[j] is INFINITY
153+
if w[j]:
154+
p_obj += alpha * weights[j] * fabs(w[j])
151155
return p_obj
152156

153157

celer/group_fast.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ cpdef floating primal_grplasso(
3333
cdef floating p_obj = fnrm2(&n_samples, &R[0], &inc) ** 2 / (2 * n_samples)
3434

3535
for g in range(n_groups):
36-
nrm = 0.
37-
38-
for k in range(grp_ptr[g], grp_ptr[g + 1]):
39-
j = grp_indices[k]
40-
nrm += w[j] ** 2
41-
p_obj += alpha * sqrt(nrm) * weights[g]
36+
if weights[g] != INFINITY:
37+
nrm = 0.
38+
for k in range(grp_ptr[g], grp_ptr[g + 1]):
39+
j = grp_indices[k]
40+
nrm += w[j] ** 2
41+
p_obj += alpha * sqrt(nrm) * weights[g]
4242

4343
return p_obj
4444

0 commit comments

Comments
 (0)