Skip to content

Commit 1678fd1

Browse files
authored
Adjust eweights calculation to avoid precision issues (#509)
* Adjust eweights calculation to avoid precision issues. The modified function is equivalent to dividing all weights by the largest value. ``` julia> x = [ 0.3 0.42857142857142855 0.6122448979591837 0.8746355685131197 1.249479383590171 1.7849705479859588 2.549957925694227 3.642797036706039 5.203995766722913 7.434279666747019 ] 10-element Array{Float64,1}: 0.3 0.42857142857142855 0.6122448979591837 0.8746355685131197 1.249479383590171 1.7849705479859588 2.549957925694227 3.642797036706039 5.203995766722913 7.434279666747019 julia> x ./ last(x) 10-element Array{Float64,1}: 0.04035360699999998 0.057648009999999965 0.08235429999999996 0.11764899999999996 0.16806999999999994 0.24009999999999995 0.34299999999999997 0.49 0.7 1.0 julia> using StatsBase [ Info: Recompiling stale cache file /Users/rory/.julia/compiled/v1.0/StatsBase/EZjIG.ji for StatsBase [2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91] julia> eweights(1:10, 0.3) 10-element Weights{Float64,Float64,Array{Float64,1}}: 0.04035360699999998 0.05764800999999997 0.08235429999999996 0.11764899999999996 0.16806999999999994 0.24009999999999995 0.3429999999999999 0.48999999999999994 0.7 1.0 ``` * Fix eweight tests. * Add a couple links to the docstring. * Address review comments and properly deprecate unscaled behaviour with a `scaled::DepBool` kwarg. * Rename scaled => scale. * Fixup docstring slightly and don't make (t, λ, n) method a public method. * More docstring updates. * Reword `n` explanation * Bump patch release
1 parent 179c533 commit 1678fd1

File tree

6 files changed

+96
-64
lines changed

6 files changed

+96
-64
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StatsBase"
22
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
33
authors = ["JuliaStats"]
4-
version = "0.33.12"
4+
version = "0.33.13"
55

66
[deps]
77
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"

src/common.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ const RealFP = Union{Float32, Float64}
2323
# A convenient typealias for deprecating default corrected Bool
2424
const DepBool = Union{Bool, Nothing}
2525

26-
function depcheck(fname::Symbol, b::DepBool)
27-
if b == nothing
28-
msg = "$fname will default to corrected=true in the future. Use corrected=false for previous behaviour."
26+
function depcheck(fname::Symbol, varname::Symbol, b::DepBool)
27+
if b === nothing
28+
msg = "$fname will default to $varname=true in the future. Use $varname=false for previous behaviour."
2929
Base.depwarn(msg, fname)
3030
false
3131
else

src/cov.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ _scattermatm(x::DenseMatrix, wv::AbstractWeights, mean, dims::Int) =
9292
## weighted cov
9393
covm(x::DenseMatrix, mean, w::AbstractWeights, dims::Int=1;
9494
corrected::DepBool=nothing) =
95-
rmul!(scattermat(x, w, mean=mean, dims=dims), varcorrection(w, depcheck(:covm, corrected)))
95+
rmul!(scattermat(x, w, mean=mean, dims=dims), varcorrection(w, depcheck(:covm, :corrected, corrected)))
9696

9797

9898
cov(x::DenseMatrix, w::AbstractWeights, dims::Int=1; corrected::DepBool=nothing) =
99-
covm(x, mean(x, w, dims=dims), w, dims; corrected=depcheck(:cov, corrected))
99+
covm(x, mean(x, w, dims=dims), w, dims; corrected=depcheck(:cov, :corrected, corrected))
100100

101101
function corm(x::DenseMatrix, mean, w::AbstractWeights, vardim::Int=1)
102102
c = covm(x, mean, w, vardim; corrected=false)
@@ -120,7 +120,7 @@ end
120120
function mean_and_cov(x::DenseMatrix, wv::AbstractWeights, dims::Int=1;
121121
corrected::DepBool=nothing)
122122
m = mean(x, wv, dims=dims)
123-
return m, cov(x, wv, dims; corrected=depcheck(:mean_and_cov, corrected))
123+
return m, cov(x, wv, dims; corrected=depcheck(:mean_and_cov, :corrected, corrected))
124124
end
125125

126126
"""

src/moments.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ the population variance is computed by replacing
1919
* `Weights`: `ArgumentError` (bias correction not supported)
2020
"""
2121
varm(v::RealArray, w::AbstractWeights, m::Real; corrected::DepBool=nothing) =
22-
_moment2(v, w, m; corrected=depcheck(:varm, corrected))
22+
_moment2(v, w, m; corrected=depcheck(:varm, :corrected, corrected))
2323

2424
"""
2525
var(x::AbstractArray, w::AbstractWeights, [dim]; mean=nothing, corrected=false)
@@ -40,7 +40,7 @@ replacing ``\\frac{1}{\\sum{w}}`` with a factor dependent on the type of weights
4040
"""
4141
function var(v::RealArray, w::AbstractWeights; mean=nothing,
4242
corrected::DepBool=nothing)
43-
corrected = depcheck(:var, corrected)
43+
corrected = depcheck(:var, :corrected, corrected)
4444

4545
if mean == nothing
4646
varm(v, w, Statistics.mean(v, w); corrected=corrected)
@@ -53,14 +53,14 @@ end
5353

5454
function varm!(R::AbstractArray, A::RealArray, w::AbstractWeights, M::RealArray,
5555
dim::Int; corrected::DepBool=nothing)
56-
corrected = depcheck(:varm!, corrected)
56+
corrected = depcheck(:varm!, :corrected, corrected)
5757
rmul!(_wsum_centralize!(R, abs2, A, convert(Vector, w), M, dim, true),
5858
varcorrection(w, corrected))
5959
end
6060

6161
function var!(R::AbstractArray, A::RealArray, w::AbstractWeights, dims::Int;
6262
mean=nothing, corrected::DepBool=nothing)
63-
corrected = depcheck(:var!, corrected)
63+
corrected = depcheck(:var!, :corrected, corrected)
6464

6565
if mean == 0
6666
varm!(R, A, w, Base.reducedim_initarray(A, dims, 0, eltype(R)), dims;
@@ -84,14 +84,14 @@ end
8484

8585
function varm(A::RealArray, w::AbstractWeights, M::RealArray, dim::Int;
8686
corrected::DepBool=nothing)
87-
corrected = depcheck(:varm, corrected)
87+
corrected = depcheck(:varm, :corrected, corrected)
8888
varm!(similar(A, Float64, Base.reduced_indices(axes(A), dim)), A, w, M,
8989
dim; corrected=corrected)
9090
end
9191

9292
function var(A::RealArray, w::AbstractWeights, dim::Int; mean=nothing,
9393
corrected::DepBool=nothing)
94-
corrected = depcheck(:var, corrected)
94+
corrected = depcheck(:var, :corrected, corrected)
9595
var!(similar(A, Float64, Base.reduced_indices(axes(A), dim)), A, w, dim;
9696
mean=mean, corrected=corrected)
9797
end
@@ -115,7 +115,7 @@ dependent on the type of weights used:
115115
* `Weights`: `ArgumentError` (bias correction not supported)
116116
"""
117117
stdm(v::RealArray, w::AbstractWeights, m::Real; corrected::DepBool=nothing) =
118-
sqrt(varm(v, w, m, corrected=depcheck(:stdm, corrected)))
118+
sqrt(varm(v, w, m, corrected=depcheck(:stdm, :corrected, corrected)))
119119

120120
"""
121121
std(x::AbstractArray, w::AbstractWeights, [dim]; mean=nothing, corrected=false)
@@ -136,18 +136,18 @@ weights used:
136136
* `Weights`: `ArgumentError` (bias correction not supported)
137137
"""
138138
std(v::RealArray, w::AbstractWeights; mean=nothing, corrected::DepBool=nothing) =
139-
sqrt.(var(v, w; mean=mean, corrected=depcheck(:std, corrected)))
139+
sqrt.(var(v, w; mean=mean, corrected=depcheck(:std, :corrected, corrected)))
140140

141141
stdm(v::RealArray, m::RealArray, dim::Int; corrected::DepBool=nothing) =
142-
sqrt!(varm(v, m, dims=dim, corrected=depcheck(:stdm, corrected)))
142+
sqrt!(varm(v, m, dims=dim, corrected=depcheck(:stdm, :corrected, corrected)))
143143

144144
stdm(v::RealArray, w::AbstractWeights, m::RealArray, dim::Int;
145145
corrected::DepBool=nothing) =
146-
sqrt.(varm(v, w, m, dim; corrected=depcheck(:stdm, corrected)))
146+
sqrt.(varm(v, w, m, dim; corrected=depcheck(:stdm, :corrected, corrected)))
147147

148148
std(v::RealArray, w::AbstractWeights, dim::Int; mean=nothing,
149149
corrected::DepBool=nothing) =
150-
sqrt.(var(v, w, dim; mean=mean, corrected=depcheck(:std, corrected)))
150+
sqrt.(var(v, w, dim; mean=mean, corrected=depcheck(:std, :corrected, corrected)))
151151

152152
##### Fused statistics
153153
"""
@@ -183,12 +183,12 @@ end
183183

184184
function mean_and_var(x::RealArray, w::AbstractWeights; corrected::DepBool=nothing)
185185
m = mean(x, w)
186-
v = varm(x, w, m; corrected=depcheck(:mean_and_var, corrected))
186+
v = varm(x, w, m; corrected=depcheck(:mean_and_var, :corrected, corrected))
187187
m, v
188188
end
189189
function mean_and_std(x::RealArray, w::AbstractWeights; corrected::DepBool=nothing)
190190
m = mean(x, w)
191-
s = stdm(x, w, m; corrected=depcheck(:mean_and_std, corrected))
191+
s = stdm(x, w, m; corrected=depcheck(:mean_and_std, :corrected, corrected))
192192
m, s
193193
end
194194

@@ -208,13 +208,13 @@ end
208208
function mean_and_var(x::RealArray, w::AbstractWeights, dims::Int;
209209
corrected::DepBool=nothing)
210210
m = mean(x, w, dims=dims)
211-
v = varm(x, w, m, dims; corrected=depcheck(:mean_and_var, corrected))
211+
v = varm(x, w, m, dims; corrected=depcheck(:mean_and_var, :corrected, corrected))
212212
m, v
213213
end
214214
function mean_and_std(x::RealArray, w::AbstractWeights, dims::Int;
215215
corrected::DepBool=nothing)
216216
m = mean(x, w, dims=dims)
217-
s = stdm(x, w, m, dims; corrected=depcheck(:mean_and_std, corrected))
217+
s = stdm(x, w, m, dims; corrected=depcheck(:mean_and_std, :corrected, corrected))
218218
m, s
219219
end
220220

src/weights.jl

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -205,59 +205,82 @@ pweights(vs::RealArray) = ProbabilityWeights(vec(vs))
205205
end
206206

207207
"""
208-
eweights(t::AbstractVector{<:Integer}, λ::Real)
209-
eweights(t::AbstractVector{T}, r::StepRange{T}, λ::Real) where T
210-
eweights(n::Integer, λ::Real)
208+
eweights(t::AbstractVector{<:Integer}, λ::Real; scale=false)
209+
eweights(t::AbstractVector{T}, r::StepRange{T}, λ::Real; scale=false) where T
210+
eweights(n::Integer, λ::Real; scale=false)
211211
212212
Construct a [`Weights`](@ref) vector which assigns exponentially decreasing weights to past
213-
observations, which in this case corresponds to larger integer values `i` in `t`.
214-
If an integer `n` is provided, weights are generated for values from 1 to `n`
215-
(equivalent to `t = 1:n`).
213+
observations (larger integer values `i` in `t`).
214+
The integer value `n` represents the number of past observations to consider.
215+
`n` defaults to `maximum(t) - minimum(t) + 1` if only `t` is passed in
216+
and the elements are integers, and to `length(r)` if a superset range `r` is also passed in.
217+
If `n` is explicitly passed instead of `t`, `t` defaults to `1:n`.
216218
217-
For each element `i` in `t` the weight value is computed as:
219+
If `scale` is `true` then for each element `i` in `t` the weight value is computed as:
220+
221+
``(1 - λ)^{n - i}``
222+
223+
If `scale` is `false` then each value is computed as:
218224
219225
``λ (1 - λ)^{1 - i}``
220226
221227
# Arguments
222228
223229
- `t::AbstractVector`: temporal indices or timestamps
224230
- `r::StepRange`: a larger range to use when constructing weights from a subset of timestamps
225-
- `n::Integer`: if provided instead of `t`, temporal indices are taken to be `1:n`
231+
- `n::Integer`: the number of past events to consider
226232
- `λ::Real`: a smoothing factor or rate parameter such that ``0 < λ ≤ 1``.
227233
As this value approaches 0, the resulting weights will be almost equal,
228234
while values closer to 1 will put greater weight on the tail elements of the vector.
229235
236+
# Keyword arguments
237+
238+
- `scale::Bool`: Return the weights scaled to between 0 and 1 (default: false)
239+
230240
# Examples
231241
```julia-repl
232-
julia> eweights(1:10, 0.3)
242+
julia> eweights(1:10, 0.3; scale=true)
233243
10-element Weights{Float64,Float64,Array{Float64,1}}:
234-
0.3
235-
0.42857142857142855
236-
0.6122448979591837
237-
0.8746355685131197
238-
1.249479383590171
239-
1.7849705479859588
240-
2.549957925694227
241-
3.642797036706039
242-
5.203995766722913
243-
7.434279666747019
244+
0.04035360699999998
245+
0.05764800999999997
246+
0.08235429999999996
247+
0.11764899999999996
248+
0.16806999999999994
249+
0.24009999999999995
250+
0.3429999999999999
251+
0.48999999999999994
252+
0.7
253+
1.0
244254
```
245-
"""
246-
function eweights(t::AbstractVector{T}, λ::Real) where T<:Integer
255+
# Links
256+
- https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average
257+
- https://en.wikipedia.org/wiki/Exponential_smoothing
258+
"""
259+
function eweights(t::AbstractVector{<:Integer}, λ::Real; kwargs...)
260+
isempty(t) && return Weights(copy(t), 0)
261+
(lo, hi) = extrema(t)
262+
return _eweights(t, λ, hi - lo + 1; kwargs...)
263+
end
264+
265+
eweights(n::Integer, λ::Real; kwargs...) = _eweights(1:n, λ, n; kwargs...)
266+
eweights(t::AbstractVector, r::AbstractRange, λ::Real; kwargs...) =
267+
_eweights(something.(indexin(t, r)), λ, length(r); kwargs...)
268+
269+
function _eweights(t::AbstractVector{<:Integer}, λ::Real, n::Integer; scale::DepBool=nothing)
247270
0 < λ <= 1 || throw(ArgumentError("Smoothing factor must be between 0 and 1"))
271+
f = depcheck(:eweights, :scale, scale) ? _scaled_eweight : _unscaled_eweight
248272

249273
w0 = map(t) do i
250274
i > 0 || throw(ArgumentError("Time indices must be non-zero positive integers"))
251-
λ * (1 - λ)^(1 - i)
275+
f(i, λ, n)
252276
end
253277

254278
s = sum(w0)
255279
Weights(w0, s)
256280
end
257281

258-
eweights(n::Integer, λ::Real) = eweights(1:n, λ)
259-
eweights(t::AbstractVector, r::AbstractRange, λ::Real) =
260-
eweights(something.(indexin(t, r)), λ)
282+
_unscaled_eweight(i, λ, n) = λ * (1 - λ)^(1 - i)
283+
_scaled_eweight(i, λ, n) = (1 - λ)^(n - i)
261284

262285
# NOTE: no variance correction is implemented for exponential weights
263286

@@ -310,7 +333,7 @@ julia> uweights(3)
310333
1
311334
1
312335
1
313-
336+
314337
julia> uweights(Float64, 3)
315338
3-element UnitWeights{Float64}:
316339
1.0

test/weights.jl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -502,53 +502,62 @@ end
502502
end
503503

504504
@testset "Exponential Weights" begin
505+
λ = 0.2
505506
@testset "Usage" begin
506-
θ = 5.25
507-
λ = 1 - exp(-1 / θ) # simple conversion for the more common/readable method
508-
v = [λ*(1-λ)^(1-i) for i = 1:4]
507+
v = [(1 - λ) ^ (4 - i) for i = 1:4]
509508
w = Weights(v)
510509

511-
@test round.(w, digits=4) == [0.1734, 0.2098, 0.2539, 0.3071]
510+
@test round.(w, digits=4) == [0.512, 0.64, 0.8, 1.0]
512511

513512
@testset "basic" begin
514-
@test eweights(1:4, λ) ≈ w
513+
@test eweights(1:4, λ; scale=true) ≈ w
515514
end
516515

517516
@testset "1:n" begin
518-
@test eweights(4, λ) ≈ w
517+
@test eweights(4, λ; scale=true) ≈ w
519518
end
520519

521520
@testset "indexin" begin
522-
v = [λ*(1-λ)^(1-i) for i = 1:10]
521+
v = [(1 - λ) ^ (10 - i) for i = 1:10]
523522

524523
# Test that we should be able to skip indices easily
525-
@test eweights([1, 3, 5, 7], 1:10, λ) ≈ Weights(v[[1, 3, 5, 7]])
524+
@test eweights([1, 3, 5, 7], 1:10, λ; scale=true) ≈ Weights(v[[1, 3, 5, 7]])
526525

527526
# This should also work with actual time types
528527
t1 = DateTime(2019, 1, 1, 1)
529528
tx = t1 + Hour(7)
530-
tn = DateTime(2019, 1, 2, 1)
529+
tn = DateTime(2019, 1, 1, 10)
531530

532-
@test eweights(t1:Hour(2):tx, t1:Hour(1):tn, λ) ≈ Weights(v[[1, 3, 5, 7]])
531+
@test eweights(t1:Hour(2):tx, t1:Hour(1):tn, λ; scale=true) ≈ Weights(v[[1, 3, 5, 7]])
533532
end
534533
end
535534

536535
@testset "Empty" begin
537-
@test eweights(0, 0.3) == Weights(Float64[])
538-
@test eweights(1:0, 0.3) == Weights(Float64[])
539-
@test eweights(Int[], 1:10, 0.4) == Weights(Float64[])
536+
@test eweights(0, 0.3; scale=true) == Weights(Float64[])
537+
@test eweights(1:0, 0.3; scale=true) == Weights(Float64[])
538+
@test eweights(Int[], 1:10, 0.4; scale=true) == Weights(Float64[])
540539
end
541540

542541
@testset "Failure Conditions" begin
543542
# λ > 1.0
544-
@test_throws ArgumentError eweights(1, 1.1)
543+
@test_throws ArgumentError eweights(1, 1.1; scale=true)
545544

546545
# time indices are not all positive non-zero integers
547-
@test_throws ArgumentError eweights([0, 1, 2, 3], 0.3)
546+
@test_throws ArgumentError eweights([0, 1, 2, 3], 0.3; scale=true)
548547

549548
# Passing in an array of bools will work because Bool <: Integer,
550549
# but any `false` values will trigger the same argument error as 0.0
551-
@test_throws ArgumentError eweights([true, false, true, true], 0.3)
550+
@test_throws ArgumentError eweights([true, false, true, true], 0.3; scale=true)
551+
end
552+
553+
@testset "scale=false" begin
554+
v = [λ * (1 - λ)^(1 - i) for i = 1:4]
555+
w = Weights(v)
556+
557+
@test round.(w, digits=4) == [0.2, 0.25, 0.3125, 0.3906]
558+
559+
wv = eweights(1:10, λ; scale=false)
560+
@test eweights(1:10, λ; scale=true) ≈ wv / maximum(wv)
552561
end
553562
end
554563

0 commit comments

Comments
 (0)