Skip to content

Commit f939e37

Browse files
committed
change bivariate to use FFTs, make interface consistent
1 parent c120087 commit f939e37

File tree

5 files changed

+145
-63
lines changed

5 files changed

+145
-63
lines changed

README.md

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ The main accessor function is `kde`:
1111
```
1212
kde(data)
1313
```
14-
will construct a `UnivariateKDE` object from the data. The optional keyword arguments are
15-
* `endpoints`: the lower and upper limits of the kde as a tuple. Due to the
14+
will construct a `UnivariateKDE` object from the real vector `data`. The optional keyword arguments are
15+
* `boundary`: the lower and upper limits of the kde as a tuple. Due to the
1616
Fourier transforms used internally, there should be sufficient spacing to
1717
prevent wrap-around at the boundaries.
1818
* `npoints`: the number of interpolation points to use. The function uses
@@ -34,16 +34,23 @@ allows specifying the internal grid to use. Optional keyword arguments are
3434
kde(data, dist::Distribution)
3535
```
3636
allows specifying the exact distribution to use as the kernel. Optional
37-
keyword arguments are `endpoints` and `npoints`.
37+
keyword arguments are `boundary` and `npoints`.
3838

3939
```
4040
kde(data, midpoints::Range, dist::Distribution)
4141
```
4242
allows specifying both the distribution and grid.
4343

44-
## To do
44+
### Bivariate
4545

46-
* Use an in-place FFT.
47-
* Spline interpolation
48-
* Bias correction
49-
* Improve bandwidth selection
46+
The usage mirrors that of the univariate case, except that `data` is now
47+
either a tuple of vectors
48+
```
49+
kde((xdata, ydata))
50+
```
51+
or a matrix with two columns
52+
```
53+
kde(datamatrix)
54+
```
55+
Similarly, the optional arguments all now take tuple arguments:
56+
e.g. `boundary` now takes a tuple of tuples `((xlo,xhi),(ylo,yhi))`.

TODO

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
* Interface to plotting libs
2+
* Use in-place FFT
3+
* Spline interpolation
4+
* Bias correction
5+
* Improve bandwidth selection (particularly for bivariate)
6+
* Allow use of arbitrary bivariate kernels

src/KDE.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ module KDE
33
using StatsBase
44
using Distributions
55

6-
import Base: conv, FloatRange
7-
import StatsBase: RealVector
6+
import Base: conv
7+
import StatsBase: RealVector, RealMatrix
88
import Distributions: twoπ
99

1010
export kde, UnivariateKDE, BivariateKDE

src/bivariate.jl

Lines changed: 98 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,114 @@
11
# Store both grid and density for KDE over R2
2-
immutable BivariateKDE
3-
x::Vector{Float64}
4-
y::Vector{Float64}
2+
immutable BivariateKDE{Rx<:Range,Ry<:Range}
3+
x::Rx
4+
y::Ry
55
density::Matrix{Float64}
66
end
77

8+
function kernel_dist{D<:UnivariateDistribution}(::Type{D},w::(Real,Real))
9+
kernel_dist(D,w[1]), kernel_dist(D,w[2])
10+
end
11+
function kernel_dist{Dx<:UnivariateDistribution,Dy<:UnivariateDistribution}(::Type{(Dx,Dy)},w::(Real,Real))
12+
kernel_dist(Dx,w[1]), kernel_dist(Dy,w[2])
13+
end
814

9-
# Algorithm from MASS Chapter 5 for calculating 2D KDE
10-
function kde(x::RealVector, y::RealVector; width::Float64=NaN, resolution::Int=25)
11-
n = length(x)
15+
# TODO: there are probably better choices.
16+
function default_bandwidth(data::(RealVector,RealVector))
17+
default_bandwidth(data[1]), default_bandwidth(data[2])
18+
end
1219

13-
if length(y) != n
14-
error("x and y must have the same length")
15-
end
20+
# tabulate data for kde
21+
function tabulate(data::(RealVector, RealVector), midpoints::(Range, Range))
22+
xdata, ydata = data
23+
ndata = length(xdata)
24+
length(ydata) == ndata || error("data vectors must be of same length")
1625

17-
if isnan(width)
18-
h1 = kde_bandwidth(x)
19-
h2 = kde_bandwidth(y)
20-
else
21-
h1 = width
22-
h2 = width
26+
xmid, ymid = midpoints
27+
nx, ny = length(xmid), length(ymid)
28+
sx, sy = step(xmid), step(ymid)
29+
30+
# Set up a grid for discretized data
31+
grid = zeros(Float64, nx, ny)
32+
ainc = 1.0 / (ndata*(sx*sy)^2)
33+
34+
# weighted discretization (cf. Jones and Lotwick)
35+
for (x, y) in zip(xdata,ydata)
36+
kx, ky = searchsortedfirst(xmid,x), searchsortedfirst(ymid,y)
37+
jx, jy = kx-1, ky-1
38+
if 1 <= jx <= nx && 1 <= jy <= ny
39+
grid[jx,jy] += (xmid[kx]-x)*(ymid[ky]-y)*ainc
40+
grid[kx,jy] += (x-xmid[jx])*(ymid[ky]-y)*ainc
41+
grid[jx,ky] += (xmid[kx]-x)*(y-ymid[jy])*ainc
42+
grid[kx,ky] += (x-xmid[jx])*(y-ymid[jy])*ainc
43+
end
2344
end
2445

25-
min_x, max_x = extrema(x)
26-
min_y, max_y = extrema(y)
46+
# returns an un-convolved KDE
47+
BivariateKDE(xmid, ymid, grid)
48+
end
49+
50+
# convolution with product distribution of two univariates distributions
51+
function conv(k::BivariateKDE, dist::(UnivariateDistribution,UnivariateDistribution) )
52+
# Transform to Fourier basis
53+
Kx, Ky = size(k.density)
54+
ft = rfft(k.density)
2755

28-
grid_x = [min_x:((max_x - min_x) / (resolution - 1)):max_x]
29-
grid_y = [min_y:((max_y - min_y) / (resolution - 1)):max_y]
56+
distx, disty = dist
3057

31-
mx = Array(Float64, resolution, n)
32-
my = Array(Float64, resolution, n)
33-
for i in 1:resolution
34-
for j in 1:n
35-
mx[i, j] = pdf(Normal(), (grid_x[i] - x[j]) / h1)
36-
my[i, j] = pdf(Normal(), (grid_y[i] - y[j]) / h2)
58+
# Convolve fft with characteristic function of kernel
59+
cx = -twoπ/(step(k.x)*Kx)
60+
cy = -twoπ/(step(k.y)*Ky)
61+
for j = 1:size(ft,2)
62+
for i = 1:size(ft,1)
63+
ft[i,j] *= cf(distx,(i-1)*cx)*cf(disty,min(j-1,Ky-j+1)*cy)
3764
end
3865
end
3966

40-
z = A_mul_Bt(mx, my)
41-
for i in 1:(resolution^2)
42-
z[i] /= (n * h1 * h2)
43-
end
67+
# Invert the Fourier transform to get the KDE
68+
BivariateKDE(k.x, k.y, irfft(ft, Kx))
69+
end
70+
71+
typealias BivariateDistribution Union(MultivariateDistribution,(UnivariateDistribution,UnivariateDistribution))
72+
73+
function kde(data::(RealVector, RealVector), midpoints::(Range, Range), dist::BivariateDistribution)
74+
k = tabulate(data,midpoints)
75+
conv(k,dist)
76+
end
77+
78+
function kde(data::(RealVector, RealVector), dist::BivariateDistribution;
79+
boundary::((Real,Real),(Real,Real)) = (kde_boundary(data[1],std(dist[1])),
80+
kde_boundary(data[2],std(dist[2]))),
81+
npoints::(Int,Int)=(128,128))
82+
83+
xmid = kde_range(boundary[1],npoints[1])
84+
ymid = kde_range(boundary[2],npoints[2])
85+
86+
kde(data,(xmid,ymid),dist)
87+
end
88+
89+
function kde(data::(RealVector, RealVector), midpoints::(Range, Range);
90+
bandwidth=default_bandwidth(data), kernel=Normal)
91+
92+
dist = kernel_dist(kernel,bandwidth)
93+
kde(data,midpoints,dist)
94+
end
95+
96+
function kde(data::(RealVector, RealVector);
97+
bandwidth=default_bandwidth(data),
98+
kernel=Normal,
99+
boundary::((Real,Real),(Real,Real)) = (kde_boundary(data[1],bandwidth[1]),
100+
kde_boundary(data[2],bandwidth[2])),
101+
npoints::(Int,Int)=(128,128))
102+
103+
dist = kernel_dist(kernel,bandwidth)
104+
xmid = kde_range(boundary[1],npoints[1])
105+
ymid = kde_range(boundary[2],npoints[2])
106+
107+
kde(data,(xmid,ymid),dist)
108+
end
44109

45-
return BivariateKDE(grid_x, grid_y, z)
110+
# matrix data
111+
function kde(data::RealMatrix,args...;kwargs...)
112+
size(data,2) == 2 || error("Can only construct KDE from matrices with 2 columns.")
113+
kde((data[:,1],data[:,2]),args...;kwargs...)
46114
end

src/univariate.jl

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ kernel_dist{D}(::Type{D},w::Real) = (s = w/std(D(0.0,1.0)); D(0.0,s))
1313

1414

1515
# Silverman's rule of thumb for KDE bandwidth selection
16-
function kde_bandwidth(data::Vector{Float64}, alpha::Float64 = 0.9)
16+
function default_bandwidth(data::Vector{Float64}, alpha::Float64 = 0.9)
1717
# Determine length of data
1818
ndata = length(data)
1919

@@ -51,11 +51,19 @@ end
5151

5252
# default kde range
5353
# Should extend enough beyond the data range to avoid cyclic correlation from the FFT
54-
function kde_range(data::RealVector, bandwidth::Real)
54+
function kde_boundary(data::RealVector, bandwidth::Real)
5555
lo, hi = extrema(data)
5656
lo - 4.0*bandwidth, hi + 4.0*bandwidth
5757
end
5858

59+
# convert boundary and npoints to Range object
60+
function kde_range(boundary::(Real,Real), npoints::Int)
61+
lo, hi = boundary
62+
lo < hi || error("boundary (a,b) must have a < b")
63+
64+
step = (hi - lo) / (npoints-1)
65+
lo:step:hi
66+
end
5967

6068
# tabulate data for kde
6169
function tabulate(data::RealVector, midpoints::Range)
@@ -67,7 +75,7 @@ function tabulate(data::RealVector, midpoints::Range)
6775
grid = zeros(Float64, npoints)
6876
ainc = 1.0 / (ndata*s*s)
6977

70-
# weighted discetization (cf. Jones and Lotwick)
78+
# weighted discretization (cf. Jones and Lotwick)
7179
for x in data
7280
k = searchsortedfirst(midpoints,x)
7381
j = k-1
@@ -81,11 +89,9 @@ function tabulate(data::RealVector, midpoints::Range)
8189
UnivariateKDE(midpoints, grid)
8290
end
8391

84-
85-
8692
# convolve raw KDE with kernel
8793
# TODO: use in-place fft
88-
function conv(k::UnivariateKDE, dist::Distribution)
94+
function conv(k::UnivariateKDE, dist::UnivariateDistribution)
8995
# Transform to Fourier basis
9096
K = length(k.density)
9197
ft = rfft(k.density)
@@ -105,34 +111,29 @@ function conv(k::UnivariateKDE, dist::Distribution)
105111
UnivariateKDE(k.x, irfft(ft, K))
106112
end
107113

108-
109-
function kde(data::RealVector, midpoints::Range, dist::Distribution)
114+
# main kde interface methods
115+
function kde(data::RealVector, midpoints::Range, dist::UnivariateDistribution)
110116
k = tabulate(data, midpoints)
111117
conv(k,dist)
112118
end
113119

114-
function kde(data::RealVector, dist::Distribution;
115-
endpoints::(Real,Real)=kde_range(data,std(dist)), npoints::Int=2048)
116-
117-
lo, hi = endpoints
118-
lo < hi || error("endpoints (a,b) must have a < b")
119-
120-
step = (hi - lo) / npoints
121-
midpoints = lo:step:hi
120+
function kde(data::RealVector, dist::UnivariateDistribution;
121+
boundary::(Real,Real)=kde_boundary(data,std(dist)), npoints::Int=2048)
122122

123+
midpoints = kde_range(boundary,npoints)
123124
kde(data,midpoints,dist)
124125
end
125126

126-
function kde(data::RealVector, midpoints::Range;
127-
bandwidth=kde_bandwidth(data), kernel=Normal)
128-
bandwidth <= 0.0 && error("Bandwidth must be positive")
127+
function kde(data::RealVector, midpoints::Range;
128+
bandwidth=default_bandwidth(data), kernel=Normal)
129+
bandwidth > 0.0 || error("Bandwidth must be positive")
129130
dist = kernel_dist(kernel,bandwidth)
130131
kde(data,midpoints,dist)
131132
end
132133

133-
function kde(data::RealVector; bandwidth=kde_bandwidth(data), kernel=Normal,
134-
npoints::Int=2048, endpoints::(Real,Real)=kde_range(data,bandwidth))
135-
bandwidth <= 0.0 && error("Bandwidth must be positive")
134+
function kde(data::RealVector; bandwidth=default_bandwidth(data), kernel=Normal,
135+
npoints::Int=2048, boundary::(Real,Real)=kde_boundary(data,bandwidth))
136+
bandwidth > 0.0 || error("Bandwidth must be positive")
136137
dist = kernel_dist(kernel,bandwidth)
137-
kde(data,dist;endpoints=endpoints,npoints=npoints)
138+
kde(data,dist;boundary=boundary,npoints=npoints)
138139
end

0 commit comments

Comments
 (0)