Skip to content

Commit 93c235f

Browse files
committed
rearrange lscv into new kde_lscv method
1 parent eb3ea4e commit 93c235f

File tree

3 files changed

+54
-26
lines changed

3 files changed

+54
-26
lines changed

src/KernelDensity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Base: conv
88
import StatsBase: RealVector, RealMatrix
99
import Distributions: twoπ
1010

11-
export kde, UnivariateKDE, BivariateKDE, bandwidth_lscv
11+
export kde, kde_lscv, UnivariateKDE, BivariateKDE
1212

1313
include("univariate.jl")
1414
include("bivariate.jl")

src/univariate.jl

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -145,35 +145,57 @@ function kde(data::RealVector; bandwidth=default_bandwidth(data), kernel=Normal,
145145
kde(data,dist;boundary=boundary,npoints=npoints)
146146
end
147147

148-
# Change M to a larger value to get better precision from lscv
149-
function bandwidth_lscv(data::RealVector; kernel::DataType=Normal, M=1024)
150-
n=length(data)
151-
h0=default_bandwidth(data)
152-
hlb = h0/sqrt(n)
153-
hub = sqrt(n)*h0
154-
xlb, xub = extrema(data)
155-
midpoints = kde_range((xlb-4*h0, xub+4*h0), M)
148+
# Select bandwidth using least-squares cross validation, from:
149+
# Density Estimation for Statistics and Data Analysis
150+
# B. W. Silverman (1986)
151+
# sections 3.4.3 (pp. 48-52) and 3.5 (pp. 61-66)
156152

153+
function kde_lscv(data::RealVector, midpoints::Range;
154+
kernel=Normal,
155+
bandwidth_range::(Real,Real)=(h=default_bandwidth(data); (0.25*h,1.5*h)))
156+
157+
ndata = length(data)
157158
k = tabulate(data, midpoints)
158-
# the ft here is M/ba*sqrt(2pi) * u(s), it is M times the Yl in Silverman's book
159-
Yl2 = abs2(rfft(k.density)/M)
160159

161-
ba = step(k.x)*M # the range b -a
162-
c = -twoπ/ba
160+
# the ft here is K/ba*sqrt(2pi) * u(s), where ba = step(k.x)*K is the range b - a; it is K times the Yl in Silverman's book
161+
K = length(k.density)
162+
ft = rfft(k.density)
163163

164-
return Optim.optimize(h -> lscv(h, Yl2, kernel, c, ba, n,M), hlb, hub).minimum
165-
end
164+
ft2 = abs2(ft)
165+
c = -twoπ/(step(k.x)*K)
166+
hlb, hub = bandwidth_range
167+
168+
opt = Optim.optimize(hlb, hub) do h
169+
dist = kernel_dist(kernel, h)
170+
ψ = 0.0
171+
for j = 1:length(ft2)-1
172+
ks = real(cf(dist, j*c))
173+
ψ += ft2[j+1]*(ks-2.0)*ks
174+
end
175+
ψ*step(k.x) + pdf(dist,0.0)/ndata
176+
end
166177

167-
# Silverman's book uses the special case of a Gaussian kernel. Here the method is generalized to any symmetric kernel
168-
function lscv(bandwidth::Real, Yl2::RealVector, kernel::DataType, c::Real, ba::Real, n::Int,M::Int)
169-
dist = kernel_dist(kernel,bandwidth)
170-
zeta_star = zeros(length(Yl2)-1)
171-
#M is even, length(Yl2) = M/2+1 and Yl2 =[y[l]^2 for l=0 :1: M/2]
172-
for j = 1:length(Yl2)-1
173-
ksl = real(cf(dist,j*c))
174-
zeta_star[j] = Yl2[j+1] * (ksl * ksl - 2 * ksl)
178+
dist = kernel_dist(kernel, opt.minimum)
179+
for j = 0:length(ft)-1
180+
ft[j+1] *= cf(dist, j*c)
175181
end
176-
# Correct the error in Silverman's book
177-
#∫ (cf^2 -2cf)u(s)²ds <- ∑(cf^2 - 2cf)*Yl2*ba²/2pi * c
178-
sum(zeta_star) * abs(c)*ba*ba/(2*pi) + pdf(dist, 0.0)/n
182+
183+
dens = irfft(ft, K)
184+
# fix rounding error.
185+
for i = 1:K
186+
dens[i] = max(0.0,dens[i])
187+
end
188+
189+
# Package the grid and the inverse-transformed density as the KDE
190+
UnivariateKDE(k.x, dens)
191+
end
192+
193+
function kde_lscv(data::RealVector;
194+
boundary::(Real,Real)=kde_boundary(data,default_bandwidth(data)),
195+
npoints::Int=2048,
196+
kernel=Normal,
197+
bandwidth_range::(Real,Real)=(h=default_bandwidth(data); (0.25*h,1.5*h)))
198+
199+
midpoints = kde_range(boundary,npoints)
200+
kde_lscv(data,midpoints; kernel=kernel, bandwidth_range=bandwidth_range)
179201
end

test/univariate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,11 @@ for X in ([0.0], [0.0,0.0], [0.0,0.5], [-0.5:0.1:0.5])
4747
@test all(k4.density .>= 0.0)
4848
@test_approx_eq sum(k4.density)*step(k4.x) 1.0
4949

50+
k5 = kde_lscv(X)
51+
@test isa(k5,UnivariateKDE)
52+
@test length(k5.density) == length(k5.x)
53+
@test all(k5.density .>= 0.0)
54+
@test_approx_eq sum(k5.density)*step(k5.x) 1.0
55+
5056
end
5157
end

0 commit comments

Comments
 (0)