Skip to content

Commit 0af7754

Browse files
committed
implement UniformWeights
1 parent 3ade534 commit 0af7754

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

src/univariate.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,20 @@ function kde_range(boundary::(@compat Tuple{Real,Real}), npoints::Int)
6666
lo:step:hi
6767
end
6868

69+
immutable UniformWeights
70+
w
71+
72+
UniformWeights(n) = new(1/n)
73+
end
74+
75+
Base.sum(x::UniformWeights) = 1.
76+
Base.getindex(x::UniformWeights, i) = x.w
77+
78+
typealias Weights Union{UniformWeights, RealVector, WeightVec}
79+
80+
6981
# tabulate data for kde
70-
function tabulate(data::RealVector, weights::RealVector, midpoints::Range)
82+
function tabulate(data::RealVector, weights::Weights, midpoints::Range)
7183
npoints = length(midpoints)
7284
s = step(midpoints)
7385

@@ -90,8 +102,7 @@ function tabulate(data::RealVector, weights::RealVector, midpoints::Range)
90102
end
91103

92104
function tabulate(data::RealVector, midpoints::Range)
93-
weights = ones(data)
94-
tabulate(data, weights, midpoints)
105+
tabulate(data, UniformWeights(length(data)), midpoints)
95106
end
96107

97108
# convolve raw KDE with kernel
@@ -122,33 +133,28 @@ function conv(k::UnivariateKDE, dist::UnivariateDistribution)
122133
UnivariateKDE(k.x, dens)
123134
end
124135

125-
function uniformweights(data)
126-
n = length(data)
127-
fill(1/n, n)
128-
end
129-
130136
# main kde interface methods
131-
function kde(data::RealVector, weights::RealVector, midpoints::Range, dist::UnivariateDistribution)
137+
function kde(data::RealVector, weights::Weights, midpoints::Range, dist::UnivariateDistribution)
132138
k = tabulate(data, weights, midpoints)
133139
conv(k,dist)
134140
end
135141

136142
function kde(data::RealVector, dist::UnivariateDistribution;
137-
boundary::(@compat Tuple{Real,Real})=kde_boundary(data,std(dist)), npoints::Int=2048, weights=uniformweights(data))
143+
boundary::(@compat Tuple{Real,Real})=kde_boundary(data,std(dist)), npoints::Int=2048, weights=UniformWeights(length(data)))
138144

139145
midpoints = kde_range(boundary,npoints)
140146
kde(data,weights,midpoints,dist)
141147
end
142148

143149
function kde(data::RealVector, midpoints::Range;
144-
bandwidth=default_bandwidth(data), kernel=Normal, weights=uniformweights(data))
150+
bandwidth=default_bandwidth(data), kernel=Normal, weights=UniformWeights(length(data)))
145151
bandwidth > 0.0 || error("Bandwidth must be positive")
146152
dist = kernel_dist(kernel,bandwidth)
147153
kde(data,weights,midpoints,dist)
148154
end
149155

150156
function kde(data::RealVector; bandwidth=default_bandwidth(data), kernel=Normal,
151-
npoints::Int=2048, boundary::(@compat Tuple{Real,Real})=kde_boundary(data,bandwidth), weights=uniformweights(data))
157+
npoints::Int=2048, boundary::(@compat Tuple{Real,Real})=kde_boundary(data,bandwidth), weights=UniformWeights(length(data)))
152158
bandwidth > 0.0 || error("Bandwidth must be positive")
153159
dist = kernel_dist(kernel,bandwidth)
154160
kde(data,dist;boundary=boundary,npoints=npoints,weights=weights)

0 commit comments

Comments
 (0)