Skip to content

Commit 0e19298

Browse files
authored
Merge pull request #111 from jbrea/dev
fix penalize_intercept = false
2 parents ca47d0d + d4c7a7f commit 0e19298

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

src/fit/analytical.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function _fit(glr::GLR{L2Loss,<:L2R}, solver::Analytical, X, y, scratch)
2424
return augment_X(X, glr.fit_intercept) \ y
2525
else
2626
# Ridge case -- form the Hat Matrix then solve
27-
H = form_XtX(X, glr.fit_intercept, λ)
27+
H = form_XtX(X, glr.fit_intercept, λ, glr.penalize_intercept)
2828
b = X'y
2929
glr.fit_intercept && (b = vcat(b, sum(y)))
3030
return cholesky!(H) \ b

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ $SIGNATURES
101101
102102
Form (X'X) while being memory aware (assuming p ≪ n).
103103
"""
104-
function form_XtX(X, fit_intercept, lambda=0)
104+
function form_XtX(X, fit_intercept, lambda = 0, penalize_intercept = true)
105105
if fit_intercept
106106
n, p = size(X)
107107
XtX = zeros(p+1, p+1)
@@ -116,7 +116,7 @@ function form_XtX(X, fit_intercept, lambda=0)
116116
end
117117
if !iszero(lambda)
118118
λ = convert(eltype(XtX), lambda)
119-
@inbounds for i in 1:size(XtX, 1)
119+
@inbounds for i in 1:size(XtX, 1) + fit_intercept * (penalize_intercept - 1)
120120
XtX[i,i] += λ
121121
end
122122
end

test/fit/ols-ridge-lasso-elnet.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,16 @@ end
2222
scale_penalty_with_samples = false)
2323
rr1 = RidgeRegression(λ, penalize_intercept=true,
2424
scale_penalty_with_samples = false)
25+
rr2 = RidgeRegression(λ, fit_intercept = true,
26+
penalize_intercept = false,
27+
scale_penalty_with_samples = false)
2528

2629
β_ref = (X'X + λ*I) \ (X'y)
2730
β_ref1 = (X1'X1 + λ*I) \ (X1'y1)
31+
β_ref2 = (X1'X1 + diagm(push!(fill(λ, p), 0))) \ (X1'y1)
2832
@test β_ref fit(rr, X, y)
2933
@test β_ref1 fit(rr1, X, y1)
34+
@test β_ref2 fit(rr2, X, y1)
3035

3136
β_cg = fit(rr, X, y; solver=CG())
3237
β_cg1 = fit(rr1, X, y1; solver=CG())

0 commit comments

Comments
 (0)