Skip to content

Commit a0b0464

Browse files
authored
Check that the correct number of weights is passed (#206)
1 parent cff6435 commit a0b0464

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

celer/group_fast.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ cpdef celer_grp(
151151
int[::1] X_indptr, floating[::1] X_mean, floating[:] y, floating alpha,
152152
floating[:] w, floating[:] R, floating[::1] theta,
153153
floating[::1] norms_X_grp, floating tol, int max_iter, int max_epochs,
154-
int gap_freq=10, floating tol_ratio_inner=0.3, int p0=100, bint prune=1, bint use_accel=1,
154+
int gap_freq=10, floating tol_ratio_inner=0.3, int p0=100,
155+
bint prune=1, bint use_accel=1,
155156
bint verbose=0):
156157

157158
pb = LASSO

celer/homotopy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,14 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
183183
X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X)
184184

185185
if weights is None:
186-
weights = np.ones(n_groups) if pb == GRPLASSO else np.ones(n_features)
187-
weights = weights.astype(X.dtype)
186+
weights = np.ones(n_features).astype(X.dtype)
188187
elif (weights <= 0).any():
189188
raise ValueError("0 or negative weights are not supported.")
189+
elif weights.shape[0] != X.shape[1]:
190+
raise ValueError(
191+
"As many weights as features must be passed. "
192+
f"Expected {X.shape[1]}, got {weights.shape[0]}."
193+
)
190194

191195
if alphas is None:
192196
if pb == LASSO:

celer/lasso_fast.pyx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,6 @@ def celer(
217217
scal = dnorm_l1(
218218
is_sparse, theta_in, X, X_data, X_indices, X_indptr,
219219
notin_WS, X_mean, weights, center, positive)
220-
# print(scal)
221-
# print(np.linalg.norm(np.asarray(X)[:, np.asarray(C)].T @ np.asarray(theta_in), ord=np.inf))
222220

223221
if scal > 1. :
224222
tmp = 1. / scal

0 commit comments

Comments
 (0)