Skip to content

Commit 6cda605

Browse files
Fix an error in ProbWeightRegression when cvxpy fails to converge (#768)
* fix: raise clear error in `ProbWeightRegression` when cvxpy fails to solve When input features have very large values, cvxpy may fail to find a solution, leaving coefficients as None. This caused a confusing error during predict (matrix multiplication with None). Now a ValueError is raised at fit time with the solver status and a suggestion to apply feature scaling. * Update sklego/linear_model.py Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --------- Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
1 parent c200c79 commit 6cda605

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

sklego/linear_model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,13 @@ def fit(self, X, y):
245245
# Solve the problem.
246246
prob = cp.Problem(objective, constraints)
247247
prob.solve()
248+
249+
if prob.status != "optimal":
250+
raise ValueError(
251+
f"cvxpy could not find a solution (status: {prob.status}).\n"
252+
"Consider feature scaling (e.g. StandardScaler) before fitting the model."
253+
)
254+
248255
self.coef_ = betas.value
249256
self.n_features_in_ = X.shape[1]
250257

tests/test_estimators/test_probweight_regression.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,14 @@ def test_shape_trained_model(random_xy_dataset_regr):
1717
mod_no_intercept = ProbWeightRegression()
1818
assert mod_no_intercept.fit(X, y).coef_.shape == (X.shape[1],)
1919
np.testing.assert_approx_equal(mod_no_intercept.fit(X, y).coef_.sum(), 1.0, significant=4)
20+
21+
22+
def test_raises_on_unsolvable_problem():
23+
"""Test that a clear error is raised when cvxpy cannot find a solution."""
24+
np.random.seed(42)
25+
X = np.random.randn(10, 5) * 1e15
26+
y = X @ np.array([2, -1, 3, 0.5, -2]) + np.random.randn(10) * 100
27+
28+
model = ProbWeightRegression()
29+
with pytest.raises(ValueError, match="cvxpy could not find a solution"):
30+
model.fit(X, y)

0 commit comments

Comments
 (0)