Skip to content

Commit 477426b

Browse files
Merge pull request #300 from LAMPSPUC/fix_missing_values
Fix missing values
2 parents e9b744b + e0167e5 commit 477426b

File tree

10 files changed

+39
-15
lines changed

10 files changed

+39
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StateSpaceModels"
22
uuid = "99342f36-827c-5390-97c9-d7f9ee765c78"
33
authors = ["raphaelsaavedra <[email protected]>, guilhermebodin <[email protected]>, mariohsouto"]
4-
version = "0.5.20"
4+
version = "0.5.21"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/models/common.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,20 @@ function lagmat(y::Vector{Fl}, k::Int) where Fl
1414
end
1515
return X
1616
end
17+
18+
function variance_of_valid_observations(y::Vector{Fl}) where Fl
19+
return var(filter(!isnan, y))
20+
end
21+
22+
function mean_of_valid_observations(y::Vector{Fl}) where Fl
23+
return mean(filter(!isnan, y))
24+
end
25+
26+
function assert_zero_missing_values(y::Vector)
27+
for el in y
28+
if isnan(el)
29+
return error("This model does not support missing values.")
30+
end
31+
end
32+
return nothing
33+
end

src/models/dar.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ A Dynamic Autorregressive model is defined by:
1919
The dynamic autorregressive model does not have the [`forecast`](@ref) method implemented yet.
2020
If you wish to perform forecasts with this model please open an issue.
2121
22+
!!! warning Missing values
23+
The dynamic autorregressive model currently does not support missing values (`NaN` observations.)
24+
2225
# Example
2326
```jldoctest
2427
julia> model = DAR(randn(100), 2)
@@ -34,6 +37,8 @@ mutable struct DAR <: StateSpaceModel
3437

3538
function DAR(y::Vector{Fl}, lags::Int) where Fl
3639

40+
assert_zero_missing_values(y)
41+
3742
X = lagmat(y, lags)
3843
num_observations = size(X, 1)
3944
first_observations = y[1:lags]
@@ -73,8 +78,8 @@ end
7378

7479
function initial_hyperparameters!(model::DAR)
7580
Fl = typeof_model_elements(model)
76-
observed_variance = var(model.system.y[findall(!isnan, model.system.y)])
77-
observed_mean = mean(model.system.y[findall(!isnan, model.system.y)])
81+
observed_variance = variance_of_valid_observations(model.system.y)
82+
observed_mean = mean_of_valid_observations(model.system.y)
7883
initial_hyperparameters = Dict{String,Fl}(
7984
"sigma2_ε" => observed_variance, "intercept" => observed_mean
8085
)

src/models/exponential_smoothing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ end
135135

136136
function initial_hyperparameters!(model::ExponentialSmoothing)
137137
Fl = typeof_model_elements(model)
138-
observations = model.system.y[findall(!isnan, model.system.y)]
139-
observed_variance = var(observations)
138+
observations = filter(!isnan, model.system.y)
139+
observed_variance = variance_of_valid_observations(model.system.y)
140140
initial_hyperparameters = Dict{String,Fl}(
141141
"sigma2" => observed_variance,
142142
"smoothing_level" => Fl(0.1),

src/models/locallevelcycle.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565

6666
function initial_hyperparameters!(model::LocalLevelCycle)
6767
Fl = typeof_model_elements(model)
68-
observed_variance = var(model.system.y[findall(!isnan, model.system.y)])
68+
observed_variance = variance_of_valid_observations(model.system.y)
6969
initial_hyperparameters = Dict{String,Fl}(
7070
"sigma2_ε" => observed_variance,
7171
"sigma2_η" => observed_variance,

src/models/locallevelexplanatory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868

6969
function initial_hyperparameters!(model::LocalLevelExplanatory)
7070
Fl = typeof_model_elements(model)
71-
observed_variance = var(model.system.y[findall(!isnan, model.system.y)])
71+
observed_variance = variance_of_valid_observations(model.system.y)
7272
initial_hyperparameters = Dict{String,Fl}(
7373
"sigma2_ε" => observed_variance, "sigma2_η" => observed_variance
7474
)

src/models/locallineartrend.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,17 @@ end
5050
function default_filter(model::LocalLinearTrend)
5151
Fl = typeof_model_elements(model)
5252
steadystate_tol = Fl(1e-5)
53-
a1 = zeros(Fl, 2)
54-
P1 = Fl(1e6) .* Matrix{Fl}(I, 2, 2)
55-
return UnivariateKalmanFilter(a1, P1, 2, steadystate_tol)
53+
a1 = zeros(Fl, num_states(model))
54+
P1 = Fl(1e6) .* Matrix{Fl}(I, num_states(model), num_states(model))
55+
return UnivariateKalmanFilter(a1, P1, num_states(model), steadystate_tol)
5656
end
5757

5858
function initial_hyperparameters!(model::LocalLinearTrend)
5959
Fl = typeof_model_elements(model)
60+
observed_variance = variance_of_valid_observations(model.system.y)
6061
initial_hyperparameters = Dict{String,Fl}(
61-
"sigma2_ε" => var(model.system.y),
62-
"sigma2_ξ" => var(model.system.y),
62+
"sigma2_ε" => observed_variance,
63+
"sigma2_ξ" => observed_variance,
6364
"sigma2_ζ" => one(Fl),
6465
)
6566
set_initial_hyperparameters!(model, initial_hyperparameters)

src/models/naive_models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ typeof_model_elements(model::NaiveModel) = eltype(model.y)
1313
function assert_zero_missing_values(model::NaiveModel)
1414
for i in 1:length(model.y)
1515
if isnan(model.y[i])
16-
return error("model $(typeof(model)) does not support missing values.)")
16+
return error("model $(typeof(model)) does not support missing values.")
1717
end
1818
end
1919
return nothing

src/models/unobserved_components.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,7 @@ end
455455

456456
function initial_hyperparameters!(model::UnobservedComponents)
457457
Fl = typeof_model_elements(model)
458-
y = filter(!isnan, model.system.y)
459-
observed_variance = var(y)
458+
observed_variance = variance_of_valid_observations(model.system.y)
460459
# TODO add heuristic for initial hyperparameters
461460
initial_hyperparameters = Dict{String,Fl}(get_names(model) .=> one(Fl))
462461
if model.has_irregular

test/models/dar.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# Runned on Python statsmodels
1010
@test loglike(model) -1175.9129 atol = 1e-5 rtol = 1e-5
1111

12+
@test_throws ErrorException DAR(vcat(rand(10), NaN, rand(10)), 2)
13+
1214
# forecasting
1315
@test_throws ErrorException forecast(model, 10)
1416
end

0 commit comments

Comments
 (0)