Skip to content

Commit 76cdbdb

Browse files
authored
Merge pull request #14 from arokem/unsorted_fracs
Make sure inputs are sorted. Otherwise, raise error.
2 parents 15ec007 + ef3ddb9 commit 76cdbdb

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

fracridge/fracridge.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from numpy import interp
66
import warnings
7+
import collections
78

89
from sklearn.base import BaseEstimator, MultiOutputMixin
910
from sklearn.utils.validation import (check_X_y, check_array, check_is_fitted,
@@ -80,8 +81,9 @@ def fracridge(X, y, fracs=None, tol=1e-10, jit=True):
8081
8182
fracs : float or 1d array, optional
8283
The desired fractions of the parameter vector length, relative to
83-
OLS solution. If 1d array, the shape is (f,).
84-
Default: np.arange(.1, 1.1, .1)
84+
OLS solution. If 1d array, the shape is (f,). This input is required
85+
to be sorted. Otherwise, raises ValueError.
86+
Default: np.arange(.1, 1.1, .1).
8587
8688
jit : bool, optional
8789
Whether to speed up computations by using a just-in-time compiled
@@ -129,7 +131,12 @@ def fracridge(X, y, fracs=None, tol=1e-10, jit=True):
129131
if fracs is None:
130132
fracs = np.arange(.1, 1.1, .1)
131133

132-
if not hasattr(fracs, "__len__"):
134+
if hasattr(fracs, "__len__"):
135+
if np.any(np.diff(fracs) < 0):
136+
raise ValueError("The `frac` inputs to the `fracridge` function ",
137+
f"must be sorted. You provided: {fracs}")
138+
139+
else:
133140
fracs = [fracs]
134141
fracs = np.array(fracs)
135142

fracridge/tests/test_fracridge.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,11 @@ def test_FracRidgeRegressorCV(nn, pp, bb, fit_intercept, jit):
130130
assert np.allclose(RR.coef_.T, FRCV.coef_, atol=10e-3)
131131

132132

133+
@pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100)])
134+
@pytest.mark.parametrize("bb", [(1), (2)])
135+
def test_fracridge_unsorted(nn, pp, bb):
136+
X, y, coef_ols, _ = make_data(nn, pp, bb)
137+
fracs = np.array([0.1, 0.8, 1.0, 0.2])
138+
# Frac input needs to be sorted:
139+
with pytest.raises(ValueError):
140+
coef, alpha = fracridge(X, y, fracs=fracs)

0 commit comments

Comments
 (0)