Skip to content

Commit 72b28be

Browse files
authored
Merge pull request #112 from jbrea/lasso
fix scale_penalty_with_samples=true for proxgrad
2 parents 9f1987c + b293b59 commit 72b28be

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed

src/fit/proxgrad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function _fit(glr::GLR, solver::ProxGrad, X, y, scratch)
2121
# functions
2222
_f = smooth_objective(glr, X, y; c=c)
2323
_fg! = smooth_fg!(glr, X, y, scratch)
24-
_prox! = prox!(glr)
24+
_prox! = prox!(glr, size(X, 1))
2525
bt_cond = θ̂ ->
2626
_f(θ̂) > fθ̄ + dot(θ̂ .- θ̄, ∇fθ̄) + sum(abs2.(θ̂ .- θ̄)) / (2η)
2727
# loop-related

src/glr/prox.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# prox_{αr}(z) = sign(z)(abs(z) - αλ)₊
1616
# ------------------------------------
1717

18-
function prox!(glr::GLR{<:Loss,<:Union{L1R,CompositePenalty}})
19-
γ = getscale_l1(glr.penalty)
18+
function prox!(glr::GLR{<:Loss,<:Union{L1R,CompositePenalty}}, n)
19+
γ = get_penalty_scale_l1(glr, n)
2020
(p, α, z) -> begin
2121
p .= soft_thresh.(z, α * γ)
2222
glr.fit_intercept && (glr.penalize_intercept || (p[end] = z[end]))

src/loss-penalty/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ getscale_l2(cp::CompositePenalty) = is_elnet(cp) ? cp |> get_l2 |> getscale :
2121

2222
get_penalty_scale(glr, n) = getscale(glr.penalty) * ifelse(glr.scale_penalty_with_samples, float(n), 1.0)
2323
get_penalty_scale_l2(glr, n) = getscale_l2(glr.penalty) * ifelse(glr.scale_penalty_with_samples, float(n), 1.0)
24+
get_penalty_scale_l1(glr, n) = getscale_l1(glr.penalty) * ifelse(glr.scale_penalty_with_samples, float(n), 1.0)

test/fit/ols-ridge-lasso-elnet.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ n, p = 500, 100
5656
@test nnz(θ_fista) == 12 # sparse
5757
@test nnz(θ_ista) == 12
5858

59+
# scale_penalty_with_samples
60+
lr_scaled = LassoRegression/n; fit_intercept=false,
61+
scale_penalty_with_samples = true)
62+
θ_scaled = fit(lr_scaled, X, y)
63+
@test θ_scaled θ_fista
64+
5965
# with intercept
6066
lr1 = LassoRegression(λ, penalize_intercept=true,
6167
scale_penalty_with_samples = false)

0 commit comments

Comments
 (0)