Skip to content

Commit 92a322e

Browse files
Fix bug in Naive assertion and increase coverage
1 parent a6d92f7 commit 92a322e

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

src/models/naive_models.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ typeof_model_elements(model::NaiveModel) = eltype(model.y)
1212

1313
function assert_zero_missing_values(model::NaiveModel)
1414
for i in 1:length(model.y)
15-
if isnan(y[i])
15+
if isnan(model.y[i])
1616
return error("model $(typeof(model)) does not support missing values.)")
1717
end
1818
end
@@ -44,6 +44,7 @@ mutable struct Naive <: NaiveModel
4444
end
4545

4646
function fit!(model::Naive)
47+
assert_zero_missing_values(model)
4748
residuals = model.y[2:end] - model.y[1:end-1]
4849
model.residuals = residuals
4950
model.sigma2 = var(residuals)
@@ -111,6 +112,7 @@ mutable struct SeasonalNaive <: NaiveModel
111112
end
112113

113114
function fit!(model::SeasonalNaive)
115+
assert_zero_missing_values(model)
114116
residuals = model.y[model.seasonal+1:end] - model.y[1:end-model.seasonal]
115117
model.residuals = residuals
116118
model.sigma2 = var(residuals)
@@ -185,6 +187,7 @@ mutable struct ExperimentalSeasonalNaive <: NaiveModel
185187
end
186188

187189
function fit!(model::ExperimentalSeasonalNaive)
190+
assert_zero_missing_values(model)
188191
residuals = model.y[model.seasonal+1:end] - model.y[1:end-model.seasonal]
189192
model.residuals = residuals
190193
model.sigma2 = var(residuals)

test/models/naive_models.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
model = Naive(nile.flow)
55
fit!(model)
6+
StateSpaceModels.get_standard_residuals(model)
67
forec = forecast(model, 10; bootstrap = true)
78
forec = forecast(model, 10)
89
scenarios = simulate_scenarios(model, 10, 1_000)
910
@test monotone_forecast_variance(forec)
1011
test_scenarios_adequacy_with_forecast(forec, scenarios; rtol = 1e-1)
12+
cross_validation(model, 10, 70; n_scenarios=100)
1113

1214
air_passengers = CSV.File(StateSpaceModels.AIR_PASSENGERS) |> DataFrame
1315
log_air_passengers = log.(air_passengers.passengers)
@@ -18,10 +20,19 @@
1820
forec = forecast(model, 60)
1921
scenarios = simulate_scenarios(model, 60, 1_000)
2022
@test monotone_forecast_variance(forec)
23+
StateSpaceModels.reinstantiate(model, model.y)
2124

2225
# Just see if it runs
2326
model = ExperimentalSeasonalNaive(log_air_passengers, 12)
2427
fit!(model)
2528
forec = forecast(model, 60)
2629
scenarios = simulate_scenarios(model, 60, 1_000)
30+
StateSpaceModels.reinstantiate(model, model.y)
31+
32+
using StateSpaceModels
33+
using Test
34+
y = randn(100)
35+
y[10] = NaN
36+
model = Naive(y)
37+
@test_throws ErrorException fit!(model)
2738
end

test/visualization/components.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,24 @@
2020
@test length(rec) == 3
2121
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, ks)
2222
@test length(rec) == 3
23+
24+
model = ExponentialSmoothing(log_air_passengers; trend = true, seasonal = 12)
25+
fit!(model)
26+
kf = kalman_filter(model)
27+
ks = kalman_smoother(model)
28+
29+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, kf)
30+
@test length(rec) == 4
31+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, ks)
32+
@test length(rec) == 4
33+
34+
finland_fatalities = CSV.File(StateSpaceModels.VEHICLE_FATALITIES) |> DataFrame
35+
log_finland_fatalities = log.(finland_fatalities.ff)
36+
model = ExponentialSmoothing(log_finland_fatalities; trend = true)
37+
fit!(model)
38+
39+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, kf)
40+
@test length(rec) == 3
41+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, ks)
42+
@test length(rec) == 3
2343
end

test/visualization/forecast.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,15 @@
1111
scen = simulate_scenarios(model, 12, 100)
1212
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, scen)
1313
@test length(rec) == 2
14+
15+
model = SeasonalNaive(log_air_passengers, 12)
16+
fit!(model)
17+
# forecasting
18+
forec = forecast(model, 12)
19+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, forec)
20+
@test length(rec) == 4
21+
# simulating
22+
scen = simulate_scenarios(model, 12, 100)
23+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, scen)
24+
@test length(rec) == 2
1425
end

0 commit comments

Comments
 (0)