Skip to content

Commit 4bcdc23

Browse files
authored
Merge pull request #26 from axsk/weights
Weighted KDE
2 parents d7524ff + 4a3680a commit 4bcdc23

File tree

4 files changed

+66
-33
lines changed

4 files changed

+66
-33
lines changed

src/bivariate.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function default_bandwidth(data::Tuple{RealVector,RealVector})
3030
end
3131

3232
# tabulate data for kde
33-
function tabulate(data::Tuple{RealVector, RealVector}, midpoints::Tuple{Range, Range})
33+
function tabulate(data::Tuple{RealVector, RealVector}, midpoints::Tuple{Range, Range}, weights::Weights = default_weights(data))
3434
xdata, ydata = data
3535
ndata = length(xdata)
3636
length(ydata) == ndata || error("data vectors must be of same length")
@@ -41,17 +41,19 @@ function tabulate(data::Tuple{RealVector, RealVector}, midpoints::Tuple{Range, R
4141

4242
# Set up a grid for discretized data
4343
grid = zeros(Float64, nx, ny)
44-
ainc = 1.0 / (ndata*(sx*sy)^2)
44+
ainc = 1.0 / (sum(weights)*(sx*sy)^2)
4545

4646
# weighted discretization (cf. Jones and Lotwick)
47-
for (x, y) in zip(xdata,ydata)
47+
for i in 1:length(xdata)
48+
x = xdata[i]
49+
y = ydata[i]
4850
kx, ky = searchsortedfirst(xmid,x), searchsortedfirst(ymid,y)
4951
jx, jy = kx-1, ky-1
5052
if 1 <= jx <= nx-1 && 1 <= jy <= ny-1
51-
grid[jx,jy] += (xmid[kx]-x)*(ymid[ky]-y)*ainc
52-
grid[kx,jy] += (x-xmid[jx])*(ymid[ky]-y)*ainc
53-
grid[jx,ky] += (xmid[kx]-x)*(y-ymid[jy])*ainc
54-
grid[kx,ky] += (x-xmid[jx])*(y-ymid[jy])*ainc
53+
grid[jx,jy] += (xmid[kx]-x)*(ymid[ky]-y)*ainc*weights[i]
54+
grid[kx,jy] += (x-xmid[jx])*(ymid[ky]-y)*ainc*weights[i]
55+
grid[jx,ky] += (xmid[kx]-x)*(y-ymid[jy])*ainc*weights[i]
56+
grid[kx,ky] += (x-xmid[jx])*(y-ymid[jy])*ainc*weights[i]
5557
end
5658
end
5759

@@ -87,41 +89,45 @@ end
8789

8890
const BivariateDistribution = Union{MultivariateDistribution,Tuple{UnivariateDistribution,UnivariateDistribution}}
8991

90-
function kde(data::Tuple{RealVector, RealVector}, midpoints::Tuple{Range, Range}, dist::BivariateDistribution)
91-
k = tabulate(data,midpoints)
92+
default_weights(data::Tuple{RealVector, RealVector}) = UniformWeights(length(data[1]))
93+
94+
function kde(data::Tuple{RealVector, RealVector}, weights::Weights, midpoints::Tuple{Range, Range}, dist::BivariateDistribution)
95+
k = tabulate(data, midpoints, weights)
9296
conv(k,dist)
9397
end
9498

9599
function kde(data::Tuple{RealVector, RealVector}, dist::BivariateDistribution;
96100
boundary::Tuple{Tuple{Real,Real}, Tuple{Real,Real}} = (kde_boundary(data[1],std(dist[1])),
97101
kde_boundary(data[2],std(dist[2]))),
98-
npoints::Tuple{Int,Int}=(256,256))
102+
npoints::Tuple{Int,Int}=(256,256),
103+
weights::Weights = default_weights(data))
99104

100105
xmid = kde_range(boundary[1],npoints[1])
101106
ymid = kde_range(boundary[2],npoints[2])
102107

103-
kde(data,(xmid,ymid),dist)
108+
kde(data,weights,(xmid,ymid),dist)
104109
end
105110

106111
function kde(data::Tuple{RealVector, RealVector}, midpoints::Tuple{Range, Range};
107-
bandwidth=default_bandwidth(data), kernel=Normal)
112+
bandwidth=default_bandwidth(data), kernel=Normal, weights::Weights = default_weights(data))
108113

109114
dist = kernel_dist(kernel,bandwidth)
110-
kde(data,midpoints,dist)
115+
kde(data,weights,midpoints,dist)
111116
end
112117

113118
function kde(data::Tuple{RealVector, RealVector};
114119
bandwidth=default_bandwidth(data),
115120
kernel=Normal,
116121
boundary::Tuple{Tuple{Real,Real}, Tuple{Real,Real}} = (kde_boundary(data[1],bandwidth[1]),
117122
kde_boundary(data[2],bandwidth[2])),
118-
npoints::Tuple{Int,Int}=(256,256))
123+
npoints::Tuple{Int,Int}=(256,256),
124+
weights::Weights = default_weights(data))
119125

120126
dist = kernel_dist(kernel,bandwidth)
121127
xmid = kde_range(boundary[1],npoints[1])
122128
ymid = kde_range(boundary[2],npoints[2])
123129

124-
kde(data,(xmid,ymid),dist)
130+
kde(data,weights,(xmid,ymid),dist)
125131
end
126132

127133
# matrix data

src/univariate.jl

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ function default_bandwidth(data::RealVector, alpha::Float64 = 0.9)
3737
return alpha * width * ndata^(-0.2)
3838
end
3939

40+
function default_weights(data::RealVector)
41+
UniformWeights(length(data))
42+
end
43+
4044

4145
# Roughly based on:
4246
# B. W. Silverman (1982) "Algorithm AS 176: Kernel Density Estimation Using
@@ -66,23 +70,32 @@ function kde_range(boundary::Tuple{Real,Real}, npoints::Int)
6670
lo:step:hi
6771
end
6872

73+
immutable UniformWeights{N} end
74+
75+
UniformWeights(n) = UniformWeights{n}()
76+
77+
Base.sum(x::UniformWeights) = 1.
78+
Base.getindex{N}(x::UniformWeights{N}, i) = 1/N
79+
80+
typealias Weights Union{UniformWeights, RealVector, WeightVec}
81+
82+
6983
# tabulate data for kde
70-
function tabulate(data::RealVector, midpoints::Range)
71-
ndata = length(data)
84+
function tabulate(data::RealVector, midpoints::Range, weights::Weights=default_weights(data))
7285
npoints = length(midpoints)
7386
s = step(midpoints)
7487

7588
# Set up a grid for discretized data
7689
grid = zeros(Float64, npoints)
77-
ainc = 1.0 / (ndata*s*s)
90+
ainc = 1.0 / (sum(weights)*s*s)
7891

7992
# weighted discretization (cf. Jones and Lotwick)
80-
for x in data
93+
for (i,x) in enumerate(data)
8194
k = searchsortedfirst(midpoints,x)
8295
j = k-1
8396
if 1 <= j <= npoints-1
84-
grid[j] += (midpoints[k]-x)*ainc
85-
grid[k] += (x-midpoints[j])*ainc
97+
grid[j] += (midpoints[k]-x)*ainc*weights[i]
98+
grid[k] += (x-midpoints[j])*ainc*weights[i]
8699
end
87100
end
88101

@@ -119,30 +132,30 @@ function conv(k::UnivariateKDE, dist::UnivariateDistribution)
119132
end
120133

121134
# main kde interface methods
122-
function kde(data::RealVector, midpoints::Range, dist::UnivariateDistribution)
123-
k = tabulate(data, midpoints)
135+
function kde(data::RealVector, weights::Weights, midpoints::Range, dist::UnivariateDistribution)
136+
k = tabulate(data, midpoints, weights)
124137
conv(k,dist)
125138
end
126139

127140
function kde(data::RealVector, dist::UnivariateDistribution;
128-
boundary::Tuple{Real,Real}=kde_boundary(data,std(dist)), npoints::Int=2048)
141+
boundary::Tuple{Real,Real}=kde_boundary(data,std(dist)), npoints::Int=2048, weights=default_weights(data))
129142

130143
midpoints = kde_range(boundary,npoints)
131-
kde(data,midpoints,dist)
144+
kde(data,weights,midpoints,dist)
132145
end
133146

134147
function kde(data::RealVector, midpoints::Range;
135-
bandwidth=default_bandwidth(data), kernel=Normal)
148+
bandwidth=default_bandwidth(data), kernel=Normal, weights=default_weights(data))
136149
bandwidth > 0.0 || error("Bandwidth must be positive")
137150
dist = kernel_dist(kernel,bandwidth)
138-
kde(data,midpoints,dist)
151+
kde(data,weights,midpoints,dist)
139152
end
140153

141154
function kde(data::RealVector; bandwidth=default_bandwidth(data), kernel=Normal,
142-
npoints::Int=2048, boundary::Tuple{Real,Real}=kde_boundary(data,bandwidth))
155+
npoints::Int=2048, boundary::Tuple{Real,Real}=kde_boundary(data,bandwidth), weights=default_weights(data))
143156
bandwidth > 0.0 || error("Bandwidth must be positive")
144157
dist = kernel_dist(kernel,bandwidth)
145-
kde(data,dist;boundary=boundary,npoints=npoints)
158+
kde(data,dist;boundary=boundary,npoints=npoints,weights=weights)
146159
end
147160

148161
# Select bandwidth using least-squares cross validation, from:
@@ -152,10 +165,11 @@ end
152165

153166
function kde_lscv(data::RealVector, midpoints::Range;
154167
kernel=Normal,
155-
bandwidth_range::Tuple{Real,Real}=(h=default_bandwidth(data); (0.25*h,1.5*h)))
168+
bandwidth_range::Tuple{Real,Real}=(h=default_bandwidth(data); (0.25*h,1.5*h)),
169+
weights=default_weights(data))
156170

157171
ndata = length(data)
158-
k = tabulate(data, midpoints)
172+
k = tabulate(data, midpoints, weights)
159173

160174
# the ft here is K/ba*sqrt(2pi) * u(s), it is K times the Yl in Silverman's book
161175
K = length(k.density)
@@ -194,8 +208,9 @@ function kde_lscv(data::RealVector;
194208
boundary::Tuple{Real,Real}=kde_boundary(data,default_bandwidth(data)),
195209
npoints::Int=2048,
196210
kernel=Normal,
197-
bandwidth_range::Tuple{Real,Real}=(h=default_bandwidth(data); (0.25*h,1.5*h)))
211+
bandwidth_range::Tuple{Real,Real}=(h=default_bandwidth(data); (0.25*h,1.5*h)),
212+
weights::Weights = default_weights(data))
198213

199214
midpoints = kde_range(boundary,npoints)
200-
kde_lscv(data,midpoints; kernel=kernel, bandwidth_range=bandwidth_range)
215+
kde_lscv(data,midpoints; kernel=kernel, bandwidth_range=bandwidth_range, weights=weights)
201216
end

test/bivariate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,11 @@ for X in ([0.0], [0.0,0.0], [0.0,0.5], [-0.5:0.1:0.5;])
5656
@test all(k5.density .>= 0.0)
5757
@test sum(k5.density)*step(k5.x)*step(k5.y) ≈ 1.0
5858

59+
k6 = kde([X X],(r,r);kernel=D, weights=ones(X)/length(X))
60+
@test k4.density ≈ k6.density
5961
end
6062
end
63+
64+
k1 = kde([0.0 0.0; 1.0 1.0], (r,r), bandwidth=(1,1), weights=[0,1])
65+
k2 = kde([1.0 1.0], (r,r), bandwidth=(1,1))
66+
@test k1.density ≈ k2.density

test/univariate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,11 @@ for X in ([0.0], [0.0,0.0], [0.0,0.5], [-0.5:0.1:0.5;])
5353
@test all(k5.density .>= 0.0)
5454
@test sum(k5.density)*step(k5.x) ≈ 1.0
5555

56+
k6 = kde(X,r;kernel=D, weights=ones(X)/length(X))
57+
@test k4.density ≈ k6.density
5658
end
5759
end
60+
61+
k1 = kde([0.0, 1.], r, bandwidth=1, weights=[0,1])
62+
k2 = kde([1.], r, bandwidth=1)
63+
@test k1.density ≈ k2.density

0 commit comments

Comments
 (0)