Commit 14eb466

first pass at Gramian training for OLS (#146)
* proof of concept
* AbstractMatrix -> AVR
* cleaner impl
* endline
* fix error type
* construct kernels if not passed in
* add test case for implicit gram construction
* last endline
* check for isempty instead of iszero
1 parent 0b48318 commit 14eb466

8 files changed (+81, -12 lines)

src/fit/default.jl (14 additions, 5 deletions)

@@ -33,13 +33,22 @@ $SIGNATURES
 Fit a generalised linear regression model using an appropriate solver based on
 the loss and penalty of the model. A method can, in some cases, be specified.
 """
-function fit(glr::GLR, X::AbstractMatrix{<:Real}, y::AVR;
+function fit(glr::GLR, X::AbstractMatrix{<:Real}, y::AVR; data=nothing,
              solver::Solver=_solver(glr, size(X)))
-    check_nrows(X, y)
-    n, p = size(X)
-    c = getc(glr, y)
-    return _fit(glr, solver, X, y, scratch(n, p, c, i=glr.fit_intercept))
+    if hasproperty(solver, :gram) && solver.gram
+        # interpret X,y as X'X, X'y
+        data = verify_or_construct_gramian(glr, X, y, data)
+        p = size(data.XX, 2)
+        return _fit(glr, solver, data.XX, data.Xy, (; dims=(data.n, p, 0)))
+    else
+        check_nrows(X, y)
+        n, p = size(X)
+        c = getc(glr, y)
+        return _fit(glr, solver, X, y, scratch(n, p, c, i=glr.fit_intercept))
+    end
 end
+fit(glr::GLR; kwargs...) = fit(glr, zeros((0,0)), zeros((0,)); kwargs...)
+
 
 function scratch(n, p, c=0; i=false)
     p_ = p + Int(i)
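
A minimal usage sketch of the two entry points added above (variable names are illustrative, not part of the commit): with a `gram` solver, either `X` and `y` are passed as usual and the Gramian is built internally, or `X'X`, `X'y` and `n` are supplied via `data` and the positional arguments are omitted.

    using MLJLinearModels
    X, y = randn(100, 5), randn(100)
    enr = ElasticNetRegression(0.1, 0.1; fit_intercept=false,
                               scale_penalty_with_samples=false)
    # implicit: the Gramian X'X, X'y is constructed from X, y internally
    θ1 = fit(enr, X, y; solver=FISTA(gram=true))
    # explicit: pass X'X, X'y, n directly and omit X, y
    θ2 = fit(enr; data=(; XX=X'X, Xy=X'y, n=100), solver=FISTA(gram=true))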

src/fit/proxgrad.jl (13 additions, 4 deletions)

@@ -3,7 +3,7 @@
 # Assumption: loss has gradient; penalty has prox e.g.: Lasso
 # J(θ) = f(θ) + r(θ) where f is smooth
 function _fit(glr::GLR, solver::ProxGrad, X, y, scratch)
-    _,p,c = npc(scratch)
+    n,p,c = npc(scratch)
     c > 0 && (p *= c)
     # vector caches + eval cache
     θ = zeros(p) # θ_k
@@ -19,9 +19,18 @@ function _fit(glr::GLR, solver::ProxGrad, X, y, scratch)
     η = 1.0 # stepsize (1/L)
     acc = ifelse(solver.accel, 1.0, 0.0) # if 0, no extrapolation (ISTA)
     # functions
-    _f = smooth_objective(glr, X, y; c=c)
-    _fg! = smooth_fg!(glr, X, y, scratch)
-    _prox! = prox!(glr, size(X, 1))
+    _f = if solver.gram
+        smooth_gram_objective(glr, X, y, n)
+    else
+        smooth_objective(glr, X, y; c=c)
+    end
+
+    _fg! = if solver.gram
+        smooth_gram_fg!(glr, X, y, n)
+    else
+        smooth_fg!(glr, X, y, scratch)
+    end
+    _prox! = prox!(glr, n)
     bt_cond = θ̂ ->
         _f(θ̂) > fθ̄ + dot(θ̂ .- θ̄, ∇fθ̄) + sum(abs2.(θ̂ .- θ̄)) / (2η)
     # loop-related
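
In the `gram` branch each iteration touches only the p×p matrix `X'X` and the p-vector `X'y`, so the per-iteration cost is O(p²) rather than O(np). A self-contained sketch of the idea (for illustration only, independent of the package internals):

    using LinearAlgebra: opnorm
    # ISTA on the Gramian form of the lasso: minimise θ'XXθ/2 - θ'Xy + λ‖θ‖₁
    soft(x, t) = sign(x) * max(abs(x) - t, 0.0)
    function ista_gram(XX, Xy, λ; η=1/opnorm(XX), iters=1_000)
        θ = zeros(size(XX, 2))
        for _ in 1:iters
            g = XX * θ .- Xy               # gradient of the smooth part
            θ = soft.(θ .- η .* g, η * λ)  # prox step on the l1 penalty
        end
        return θ
    end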

src/fit/solvers.jl (1 addition, 0 deletions)

@@ -133,6 +133,7 @@ Proximal Gradient solver for non-smooth objective functions.
     tol::Float64 = 1e-4 # tol relative change of θ i.e. norm(θ-θ_)/norm(θ)
     max_inner::Int = 100 # β^max_inner should be > 1e-10
     beta::Float64 = 0.8 # in (0, 1); shrinkage in the backtracking step
+    gram::Bool = false # use precomputed Gramian for lsq where possible
 end

 FISTA(; kwa...) = ProxGrad(;accel = true, kwa...)
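
Since `FISTA` forwards its keyword arguments to `ProxGrad`, the new flag can be set through either constructor:

    FISTA(gram=true)                 # ProxGrad(accel=true, gram=true)
    ProxGrad(accel=false, gram=true) # unaccelerated variant with the Gramian objective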

src/glr/d_l2loss.jl (9 additions, 0 deletions)

@@ -72,3 +72,12 @@ function smooth_fg!(glr::GLR{L2Loss,<:ENR}, X, y, scratch)
         return glr.loss(r) + get_l2(glr.penalty)(view_θ(glr, θ))
     end
 end
+
+function smooth_gram_fg!(glr::GLR{L2Loss,<:ENR}, XX, Xy, n)
+    λ = get_penalty_scale_l2(glr, n)
+    (g, θ) -> begin
+        _g = XX * θ .- Xy
+        g .= _g .+ λ .* θ
+        return θ'*_g + get_l2(glr.penalty)(view_θ(glr, θ))
+    end
+end
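
The closure above relies on the identity X'(Xθ - y) = (X'X)θ - X'y, i.e. the ridge-penalised least-squares gradient can be evaluated from the Gramian alone. A quick illustrative check (names made up for the example):

    X, y, λ, θ = randn(50, 3), randn(50), 0.1, randn(3)
    g_full = X' * (X * θ .- y) .+ λ .* θ   # gradient from the raw data
    g_gram = (X'X) * θ .- (X'y) .+ λ .* θ  # gradient from the Gramian only
    @assert g_full ≈ g_gram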

src/glr/utils.jl (4 additions, 1 deletion)

@@ -1,4 +1,4 @@
-export objective, smooth_objective
+export objective, smooth_objective, smooth_gram_objective

 # NOTE: RobustLoss are not always everywhere smooth but "smooth-enough".
 const SmoothLoss = Union{L2Loss, LogisticLoss, MultinomialLoss, RobustLoss}
@@ -37,6 +37,9 @@ Return the smooth part of the objective function of a GLR.
 """
 smooth_objective(glr::GLR{<:SmoothLoss,<:ENR}, n) = glr.loss + get_l2(glr.penalty) * ifelse(glr.scale_penalty_with_samples, n, 1.)

+smooth_gram_objective(glr::GLR{<:SmoothLoss,<:ENR}, XX, Xy, n) =
+    θ -> (θ'*XX*θ)/2 - (θ'*Xy) + (get_l2(glr.penalty) * ifelse(glr.scale_penalty_with_samples, n, 1.))(θ)
+
 smooth_objective(::GLR) = @error "Case not implemented yet."

 """

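The Gramian objective above equals the usual least-squares objective ‖Xθ - y‖²/2 up to the constant ‖y‖²/2, which does not depend on θ and therefore leaves the minimiser unchanged. An illustrative check:

    X, y, θ = randn(50, 3), randn(50), randn(3)
    full = sum(abs2, X * θ .- y) / 2
    gram = (θ' * (X'X) * θ) / 2 - θ' * (X'y)
    @assert full ≈ gram + sum(abs2, y) / 2
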
src/mlj/classifiers.jl (2 additions, 2 deletions)

@@ -65,7 +65,7 @@ See also [`MultinomialClassifier`](@ref).
 """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`, `Newton`,
 `NewtonCG`, `ProxGrad`; but subject to the following restrictions:

-- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxyGrad` is the only
+- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxGrad` is the only
   option.

 - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
@@ -142,7 +142,7 @@ See also [`LogisticClassifier`](@ref).
 """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`,
 `NewtonCG`, `ProxGrad`; but subject to the following restrictions:

-- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxyGrad` is the only
+- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxGrad` is the only
   option.

 - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.

src/utils.jl (23 additions, 0 deletions)

@@ -9,6 +9,29 @@ function check_nrows(X::AbstractMatrix, y::AbstractVecOrMat)::Nothing
     throw(DimensionMismatch("`X` and `y` must have the same number of rows."))
 end

+function verify_or_construct_gramian(glr, X, y, data)
+    check_nrows(X, y)
+    isnothing(data) && return (; XX = X'X, Xy = X'y, n = length(y))
+
+    !all(hasproperty.(Ref(data), (:XX, :Xy, :n))) && throw(ArgumentError("data must contain XX, Xy, n"))
+    size(data.XX, 1) != size(data.Xy, 1) && throw(DimensionMismatch("`XX` and `Xy` must have the same number of rows."))
+    !issymmetric(data.XX) && throw(ArgumentError("Input `XX` must be symmetric"))
+
+    c = getc(glr, data.Xy)
+    !iszero(c) && throw(ArgumentError("Categorical loss not supported with Gramian kernel"))
+    glr.fit_intercept && throw(ArgumentError("Intercept not supported with Gramian kernel"))
+
+    if any(!isempty, (X, y))
+        all((
+            isapprox(X'X, data.XX; rtol=1e-5),
+            isapprox(X'y, data.Xy; rtol=1e-5),
+            length(y) == data.n
+        )) || throw(ArgumentError("Inputs `X` and `y` do not match inputs `XX` and `Xy`."))
+    end
+
+    return data
+end
+
 """
 $SIGNATURES
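
A sketch of how the internal helper behaves when called directly (reusing the setup from the test at the bottom of this commit; the error message is the one thrown above):

    data = (; XX = X'X, Xy = X'y, n = length(y))
    verify_or_construct_gramian(enr, X, y, data)     # consistent: returns data
    verify_or_construct_gramian(enr, X, y, nothing)  # builds (; XX=X'X, Xy=X'y, n)
    verify_or_construct_gramian(enr, X, y, (; XX = data.XX, Xy = data.Xy))
    # throws ArgumentError("data must contain XX, Xy, n")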

test/fit/ols-ridge-lasso-elnet.jl (15 additions, 0 deletions)

@@ -146,3 +146,18 @@ end
     @test nnz(θ_sk) == 8
 end
 end
+
+@testset "gramian" begin
+    λ = 0.1
+    γ = 0.1
+    enr = ElasticNetRegression(λ, γ; fit_intercept=false,
+                               scale_penalty_with_samples=false)
+    XX = X'X
+    Xy = X'y
+    n = size(X, 1)
+    θ_fista = fit(enr, X, y; solver=FISTA(max_iter=5000))
+    θ_gram_explicit = fit(enr; data=(; XX, Xy, n), solver=FISTA(max_iter=5000, gram=true))
+    θ_gram_implicit = fit(enr, X, y; solver=FISTA(max_iter=5000, gram=true))
+    @test isapprox(θ_fista, θ_gram_explicit, rtol=1e-5)
+    @test isapprox(θ_gram_explicit, θ_gram_implicit; rtol=1e-5)
+end
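
One likely payoff of the explicit form (an assumption about intended use, not stated in the commit): when n is much larger than p, `X'X` and `X'y` can be computed once and reused across many fits, e.g. along a regularisation path:

    XX, Xy, n = X'X, X'y, size(X, 1)
    solver = FISTA(max_iter=5000, gram=true)
    θs = [fit(ElasticNetRegression(λ, 0.1; fit_intercept=false,
                                   scale_penalty_with_samples=false);
              data=(; XX, Xy, n), solver=solver)
          for λ in (0.01, 0.1, 1.0)]

Each such fit then costs O(p²) per iteration regardless of n.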
