Skip to content

Commit ab6b0de

Browse files
Merge pull request #272 from LAMPSPUC/cross_validation
Change from backtest to cross validation
2 parents e15df3b + 92a322e commit ab6b0de

File tree

11 files changed

+74
-29
lines changed

11 files changed

+74
-29
lines changed

docs/src/examples.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ smoother_output = kalman_smoother(model)
193193
plot(df.date, get_smoothed_state(smoother_output)[:, 2], label = "slope")
194194
```
195195

196-
## Backtest the forecasts of a model
196+
## Cross validation of the forecasts of a model
197197

198198
Often times users would like to compare the forecasting skill of different models. The function
199-
[`backtest`](@ref) makes it easy to make a rolling window scheme of estimations and forecasts
199+
[`cross_validation`](@ref) makes it easy to make a rolling window scheme of estimations and forecasts
200200
that allow users to track each model forecasting skill per lead time. A simple plot recipe is
201201
implemented to help users to interpret the results easily.
202202

@@ -210,6 +210,6 @@ using CSV, DataFrames
210210
air_passengers = CSV.File(StateSpaceModels.AIR_PASSENGERS) |> DataFrame
211211
log_air_passengers = log.(air_passengers.passengers)
212212
model = BasicStructural(log_air_passengers, 12)
213-
b = backtest(model, 24, 50)
213+
b = cross_validation(model, 24, 50)
214214
plot(b, "Basic structural model")
215215
```

docs/src/manual.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@ isfitted
154154

155155
StateSpaceModels.jl has functions to make forecasts of the predictive densities multiple steps ahead and to
156156
simulate scenarios based on those forecasts. The package also has a functions to benchmark the model forecasts
157-
using backtest techniques.
157+
using cross_validation techniques.
158158

159159
```@docs
160160
forecast
161161
simulate_scenarios
162-
backtest
162+
cross_validation
163163
```
164164

165165
## Visualization

src/StateSpaceModels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ include("optimizers.jl")
3333
include("fit.jl")
3434
include("prints.jl")
3535
include("forecast.jl")
36-
include("backtest.jl")
36+
include("cross_validation.jl")
3737

3838
include("models/common.jl")
3939
include("models/locallevel.jl")
@@ -52,7 +52,7 @@ include("models/dar.jl")
5252

5353
include("visualization/forecast.jl")
5454
include("visualization/components.jl")
55-
include("visualization/backtest.jl")
55+
include("visualization/cross_validation.jl")
5656
include("visualization/diagnostics.jl")
5757

5858
# Exported types and structs
@@ -82,7 +82,7 @@ export UnobservedComponents
8282

8383
# Exported functions
8484
export auto_ets
85-
export backtest
85+
export cross_validation
8686
export constrain_box!
8787
export constrain_identity!
8888
export constrain_variance!

src/backtest.jl renamed to src/cross_validation.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
struct Backtest{Fl <: AbstractFloat}
1+
struct CrossValidation{Fl <: AbstractFloat}
22
abs_errors::Matrix{Fl}
33
mae::Vector{Fl}
44
crps_scores::Matrix{Fl}
55
mean_crps::Vector{Fl}
6-
function Backtest{Fl}(n::Int, steps_ahead::Int) where Fl
6+
function CrossValidation{Fl}(n::Int, steps_ahead::Int) where Fl
77
abs_errors = Matrix{Fl}(undef, steps_ahead, n)
88
crps_scores = Matrix{Fl}(undef, steps_ahead, n)
99
mae = Vector{Fl}(undef, steps_ahead)
@@ -34,27 +34,27 @@ function evaluate_crps(y::Vector{Fl}, scenarios::Matrix{Fl}) where {Fl}
3434
end
3535

3636
"""
37-
backtest(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
37+
cross_validation(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
3838
n_scenarios::Int = 10_000,
3939
filter::KalmanFilter=default_filter(model),
4040
optimizer::Optimizer=default_optimizer(model)) where Fl
4141
4242
Makes rolling window estimating and forecasting to benchmark the forecasting skill of the model
4343
in for different time periods and different lead times. The function returns a struct with the MAE
44-
and mean CRPS per lead time. See more on [Backtest the forecasts of a model](@ref)
44+
and mean CRPS per lead time. See more on [CrossValidation the forecasts of a model](@ref)
4545
4646
# References
4747
* DTU course "31761 - Renewables in electricity markets" available on youtube https://www.youtube.com/watch?v=Ffo8XilZAZw&t=556s
4848
"""
49-
function backtest(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
49+
function cross_validation(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
5050
n_scenarios::Int = 10_000,
5151
filter::KalmanFilter=default_filter(model),
5252
optimizer::Optimizer=default_optimizer(model))
5353
Fl = typeof_model_elements(model)
5454
num_mle = length(model.system.y) - start_idx - steps_ahead
55-
b = Backtest{Fl}(num_mle, steps_ahead)
55+
cv = CrossValidation{Fl}(num_mle, steps_ahead)
5656
for i in 1:num_mle
57-
println("Backtest: step $i of $num_mle")
57+
println("CrossValidation: step $i of $num_mle")
5858
y_to_fit = model.system.y[1:start_idx - 1 + i]
5959
y_to_verify = model.system.y[start_idx + i:start_idx - 1 + i + steps_ahead]
6060
model_to_fit = reinstantiate(model, y_to_fit)
@@ -64,12 +64,12 @@ function backtest(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
6464
expected_value_vector = forecast_expected_value(forec)[:]
6565
abs_errors = evaluate_abs_error(y_to_verify, expected_value_vector)
6666
crps_scores = evaluate_crps(y_to_verify, scenarios[:, 1, :])
67-
b.abs_errors[:, i] = abs_errors
68-
b.crps_scores[:, i] = crps_scores
67+
cv.abs_errors[:, i] = abs_errors
68+
cv.crps_scores[:, i] = crps_scores
6969
end
7070
for i in 1:steps_ahead
71-
b.mae[i] = mean(b.abs_errors[i, :])
72-
b.mean_crps[i] = mean(b.crps_scores[i, :])
71+
cv.mae[i] = mean(cv.abs_errors[i, :])
72+
cv.mean_crps[i] = mean(cv.crps_scores[i, :])
7373
end
74-
return b
74+
return cv
7575
end

src/models/naive_models.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ typeof_model_elements(model::NaiveModel) = eltype(model.y)
1212

1313
function assert_zero_missing_values(model::NaiveModel)
1414
for i in 1:length(model.y)
15-
if isnan(y[i])
15+
if isnan(model.y[i])
1616
return error("model $(typeof(model)) does not support missing values.)")
1717
end
1818
end
@@ -44,6 +44,7 @@ mutable struct Naive <: NaiveModel
4444
end
4545

4646
function fit!(model::Naive)
47+
assert_zero_missing_values(model)
4748
residuals = model.y[2:end] - model.y[1:end-1]
4849
model.residuals = residuals
4950
model.sigma2 = var(residuals)
@@ -111,6 +112,7 @@ mutable struct SeasonalNaive <: NaiveModel
111112
end
112113

113114
function fit!(model::SeasonalNaive)
115+
assert_zero_missing_values(model)
114116
residuals = model.y[model.seasonal+1:end] - model.y[1:end-model.seasonal]
115117
model.residuals = residuals
116118
model.sigma2 = var(residuals)
@@ -185,6 +187,7 @@ mutable struct ExperimentalSeasonalNaive <: NaiveModel
185187
end
186188

187189
function fit!(model::ExperimentalSeasonalNaive)
190+
assert_zero_missing_values(model)
188191
residuals = model.y[model.seasonal+1:end] - model.y[1:end-model.seasonal]
189192
model.residuals = residuals
190193
model.sigma2 = var(residuals)
@@ -226,13 +229,13 @@ function reinstantiate(model::ExperimentalSeasonalNaive, y::Vector{<:Real})
226229
return ExperimentalSeasonalNaive(y, model.seasonal; S = model.S)
227230
end
228231

229-
function backtest(model::NaiveModel, steps_ahead::Int, start_idx::Int;
232+
function cross_validation(model::NaiveModel, steps_ahead::Int, start_idx::Int;
230233
n_scenarios::Int = 10_000)
231234
Fl = typeof_model_elements(model)
232235
num_fits = length(model.y) - start_idx - steps_ahead
233-
b = Backtest{Fl}(num_fits, steps_ahead)
236+
b = CrossValidation{Fl}(num_fits, steps_ahead)
234237
for i in 1:num_fits
235-
println("Backtest: step $i of $num_fits")
238+
println("CrossValidation: step $i of $num_fits")
236239
y_to_fit = model.y[1:start_idx - 1 + i]
237240
y_to_verify = model.y[start_idx + i:start_idx - 1 + i + steps_ahead]
238241
model_to_fit = reinstantiate(model, y_to_fit)
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
@recipe function f(b::Backtest, name::String)
1+
@recipe function f(cv::CrossValidation, name::String)
22
xguide := "lead times"
33
@series begin
44
seriestype := :path
55
label := "MAE " * name
66
marker := :circle
7-
b.mae
7+
cv.mae
88
end
99
@series begin
1010
seriestype := :path
1111
label := "Mean CRPS " * name
1212
marker := :circle
13-
b.mean_crps
13+
cv.mean_crps
1414
end
1515
end

test/models/naive_models.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
model = Naive(nile.flow)
55
fit!(model)
6+
StateSpaceModels.get_standard_residuals(model)
67
forec = forecast(model, 10; bootstrap = true)
78
forec = forecast(model, 10)
89
scenarios = simulate_scenarios(model, 10, 1_000)
910
@test monotone_forecast_variance(forec)
1011
test_scenarios_adequacy_with_forecast(forec, scenarios; rtol = 1e-1)
12+
cross_validation(model, 10, 70; n_scenarios=100)
1113

1214
air_passengers = CSV.File(StateSpaceModels.AIR_PASSENGERS) |> DataFrame
1315
log_air_passengers = log.(air_passengers.passengers)
@@ -18,10 +20,19 @@
1820
forec = forecast(model, 60)
1921
scenarios = simulate_scenarios(model, 60, 1_000)
2022
@test monotone_forecast_variance(forec)
23+
StateSpaceModels.reinstantiate(model, model.y)
2124

2225
# Just see if it runs
2326
model = ExperimentalSeasonalNaive(log_air_passengers, 12)
2427
fit!(model)
2528
forec = forecast(model, 60)
2629
scenarios = simulate_scenarios(model, 60, 1_000)
30+
StateSpaceModels.reinstantiate(model, model.y)
31+
32+
using StateSpaceModels
33+
using Test
34+
y = randn(100)
35+
y[10] = NaN
36+
model = Naive(y)
37+
@test_throws ErrorException fit!(model)
2738
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ include("models/dar.jl")
3131
# Visualization
3232
include("visualization/forecast.jl")
3333
include("visualization/components.jl")
34-
include("visualization/backtest.jl")
34+
include("visualization/cross_validation.jl")
3535
include("visualization/diagnostics.jl")

test/visualization/components.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,24 @@
2020
@test length(rec) == 3
2121
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, ks)
2222
@test length(rec) == 3
23+
24+
model = ExponentialSmoothing(log_air_passengers; trend = true, seasonal = 12)
25+
fit!(model)
26+
kf = kalman_filter(model)
27+
ks = kalman_smoother(model)
28+
29+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, kf)
30+
@test length(rec) == 4
31+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, ks)
32+
@test length(rec) == 4
33+
34+
finland_fatalities = CSV.File(StateSpaceModels.VEHICLE_FATALITIES) |> DataFrame
35+
log_finland_fatalities = log.(finland_fatalities.ff)
36+
model = ExponentialSmoothing(log_finland_fatalities; trend = true)
37+
fit!(model)
38+
39+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, kf)
40+
@test length(rec) == 3
41+
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), model, ks)
42+
@test length(rec) == 3
2343
end

test/visualization/backtest.jl renamed to test/visualization/cross_validation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
log_air_passengers = log.(air_passengers.passengers)
44
model = BasicStructural(log_air_passengers, 12)
55
# forecasting
6-
b = backtest(model, 24, 110)
6+
b = cross_validation(model, 24, 110)
77
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), b, "str")
88
@test length(rec) == 2
99
end

0 commit comments

Comments
 (0)