Skip to content

Commit be9881b

Browse files
pievertpapp
andauthored
Remove optim in favor of independent golden section implementation (#84)
* remove Optim dependency * test optimizer * minor reorg * Apply suggestions from code review Co-authored-by: Tamas K. Papp <[email protected]> * add docstring for optimize Co-authored-by: Tamas K. Papp <[email protected]>
1 parent e9d0c1b commit be9881b

File tree

5 files changed

+79
-5
lines changed

5 files changed

+79
-5
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
examples/.ipynb_checkpoints/*
2+
Manifest.toml

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1010
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
11-
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1211
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1312

1413
[compat]
1514
Interpolations = "≥ 0.9"
16-
Optim = "≥ 0.16"
1715
julia = "^1"
1816

1917
[extras]

src/KernelDensity.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module KernelDensity
33
using DocStringExtensions: TYPEDEF, FIELDS
44
using StatsBase
55
using Distributions
6-
using Optim
76
using Interpolations
87

98
import StatsBase: RealVector, RealMatrix

src/univariate.jl

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,76 @@ function kde(data::RealVector; bandwidth=default_bandwidth(data), kernel=Normal,
173173
kde(data,dist;boundary=boundary,npoints=npoints,weights=weights)
174174
end
175175

176+
"""
177+
optimize(f, x_lower, x_upper; iterations=1000, rel_tol=nothing, abs_tol=nothing)
178+
179+
Minimize the function `f` in the interval `x_lower..x_upper`, using the
180+
[golden-section search](https://en.wikipedia.org/wiki/Golden-section_search).
181+
Return an approximate minimum `x̃` or error if such approximate minimum cannot be found.
182+
183+
This algorithm assumes that `-f` is unimodal on the interval `x_lower..x_upper`,
184+
that is to say, there exists a unique `x` in `x_lower..x_upper` such that `f` is
185+
decreasing on `x_lower..x` and increasing on `x..x_upper`.
186+
187+
`rel_tol` and `abs_tol` determine the relative and absolute tolerance, that is
188+
to say, the returned value `x̃` should differ from the actual minimum `x` at most
189+
`abs_tol + rel_tol * abs(x̃)`.
190+
If not manually specified, `rel_tol` and `abs_tol` default to `sqrt(eps(T))` and
191+
`eps(T)` respectively, where `T` is the floating point type of `x_lower` and `x_upper`.
192+
193+
`iterations` determines the maximum number of iterations allowed before convergence.
194+
195+
This is a private, unexported function, used internally to select the optimal bandwidth
196+
automatically.
197+
"""
198+
function optimize(f, x_lower, x_upper; iterations=1000, rel_tol=nothing, abs_tol=nothing)
199+
200+
if x_lower > x_upper
201+
error("x_lower must be less than x_upper")
202+
end
203+
204+
T = promote_type(typeof(x_lower/1), typeof(x_upper/1))
205+
rtol = something(rel_tol, sqrt(eps(T)))
206+
atol = something(abs_tol, eps(T))
207+
208+
function midpoint_and_convergence(lower, upper)
209+
midpoint = (lower + upper) / 2
210+
tol = atol + rtol * midpoint
211+
midpoint, (upper - lower) <= 2tol
212+
end
213+
214+
invphi::T = 0.5 * (sqrt(5) - 1)
215+
invphisq::T = 0.5 * (3 - sqrt(5))
216+
217+
a::T, b::T = x_lower, x_upper
218+
h = b - a
219+
c = a + invphisq * h
220+
d = a + invphi * h
221+
222+
fc, fd = f(c), f(d)
223+
224+
for _ in 1:1000
225+
h *= invphi
226+
if fc < fd
227+
m, converged = midpoint_and_convergence(a, d)
228+
converged && return m
229+
b = d
230+
d, fd = c, fc
231+
c = a + invphisq * h
232+
fc = f(c)
233+
else
234+
m, converged = midpoint_and_convergence(c, b)
235+
converged && return m
236+
a = c
237+
c, fc = d, fd
238+
d = a + invphi * h
239+
fd = f(d)
240+
end
241+
end
242+
243+
error("Reached maximum number of iterations without convergence.")
244+
end
245+
176246
# Select bandwidth using least-squares cross validation, from:
177247
# Density Estimation for Statistics and Data Analysis
178248
# B. W. Silverman (1986)
@@ -194,7 +264,7 @@ function kde_lscv(data::RealVector, midpoints::R;
194264
c = -twoπ/(step(k.x)*K)
195265
hlb, hub = bandwidth_range
196266

197-
opt = Optim.optimize(hlb, hub) do h
267+
minimizer = optimize(hlb, hub) do h
198268
dist = kernel_dist(kernel, h)
199269
ψ = 0.0
200270
for j = 1:length(ft2)-1
@@ -204,7 +274,7 @@ function kde_lscv(data::RealVector, midpoints::R;
204274
ψ*step(k.x)/K + pdf(dist,0.0)/ndata
205275
end
206276

207-
dist = kernel_dist(kernel, Optim.minimizer(opt))
277+
dist = kernel_dist(kernel, minimizer)
208278
for j = 0:length(ft)-1
209279
ft[j+1] *= cf(dist, j*c)
210280
end

test/univariate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,9 @@ end
6363
k11 = kde([0.0, 1.], r, bandwidth=1, weights=[0,1])
6464
k12 = kde([1.], r, bandwidth=1)
6565
@test k11.density k12.density
66+
67+
rel_tol = sqrt(eps(Float64))
68+
abs_tol = eps(Float64)
69+
70+
minimizer = @inferred KernelDensity.optimize(x -> (x - 1)^2, 0, 2)
71+
@test abs(minimizer - 1) <= abs_tol + rel_tol * abs(minimizer)

0 commit comments

Comments
 (0)