|
6 | 6 | from numpy.testing import assert_allclose, assert_array_less |
7 | 7 |
|
8 | 8 | from skglm.datafits import (Huber, Logistic, Poisson, Gamma, Cox, WeightedQuadratic, |
9 | | - Quadratic,) |
| 9 | + Quadratic, QuadraticHessian) |
10 | 10 | from skglm.penalties import L1, WeightedL1 |
11 | 11 | from skglm.solvers import AndersonCD, ProxNewton |
12 | 12 | from skglm import GeneralizedLinearEstimator |
@@ -219,5 +219,24 @@ def test_sample_weights(fit_intercept): |
219 | 219 | # np.testing.assert_equal(n_iter, n_iter_overs) |
220 | 220 |
|
221 | 221 |
|
def test_HessianQuadratic():
    """Check QuadraticHessian solves the same problem as Quadratic.

    The Lasso on (X, y) with the Quadratic datafit is equivalent to the
    quadratic program built from A = X^T X / n and b = -X^T y / n with the
    QuadraticHessian datafit; both fits must yield the same coefficients.
    """
    n_samples = 20
    n_features = 10
    X, y, _ = make_correlated_data(
        n_samples=n_samples, n_features=n_features, random_state=0)
    # Precompute the Hessian form of the quadratic objective.
    A = X.T @ X / n_samples
    b = -X.T @ y / n_samples
    # alpha_max / 10 keeps the solution sparse but nonzero.
    alpha = np.max(np.abs(b)) / 10

    pen = L1(alpha)
    # Quiet solver: verbose output only clutters the test log.
    solv = AndersonCD(warm_start=False, verbose=0, fit_intercept=False)
    lasso = GeneralizedLinearEstimator(Quadratic(), pen, solv).fit(X, y)
    qpl1 = GeneralizedLinearEstimator(QuadraticHessian(), pen, solv).fit(A, b)

    # Use the names imported from numpy.testing at the top of the file,
    # consistent with the rest of this test module.
    assert_allclose(lasso.coef_, qpl1.coef_)
    # check that it's not just because we got alpha too high and thus 0 coef
    assert_array_less(0.1, np.max(np.abs(qpl1.coef_)))
| 240 | + |
# Module is meant to be run through pytest; no standalone script behavior.
if __name__ == '__main__':
    pass
0 commit comments