Skip to content

Commit a6d92f7

Browse files
Change from backtest to cross validation
1 parent e15df3b commit a6d92f7

File tree

8 files changed

+28
-28
lines changed

8 files changed

+28
-28
lines changed

docs/src/examples.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ smoother_output = kalman_smoother(model)
193193
plot(df.date, get_smoothed_state(smoother_output)[:, 2], label = "slope")
194194
```
195195

196-
## Backtest the forecasts of a model
196+
## Cross validation of the forecasts of a model
197197

198198
Often times users would like to compare the forecasting skill of different models. The function
199-
[`backtest`](@ref) makes it easy to make a rolling window scheme of estimations and forecasts
199+
[`cross_validation`](@ref) makes it easy to make a rolling window scheme of estimations and forecasts
200200
that allow users to track each model forecasting skill per lead time. A simple plot recipe is
201201
implemented to help users to interpret the results easily.
202202

@@ -210,6 +210,6 @@ using CSV, DataFrames
210210
air_passengers = CSV.File(StateSpaceModels.AIR_PASSENGERS) |> DataFrame
211211
log_air_passengers = log.(air_passengers.passengers)
212212
model = BasicStructural(log_air_passengers, 12)
213-
b = backtest(model, 24, 50)
213+
b = cross_validation(model, 24, 50)
214214
plot(b, "Basic structural model")
215215
```

docs/src/manual.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@ isfitted
154154

155155
StateSpaceModels.jl has functions to make forecasts of the predictive densities multiple steps ahead and to
156156
simulate scenarios based on those forecasts. The package also has a functions to benchmark the model forecasts
157-
using backtest techniques.
157+
using cross_validation techniques.
158158

159159
```@docs
160160
forecast
161161
simulate_scenarios
162-
backtest
162+
cross_validation
163163
```
164164

165165
## Visualization

src/StateSpaceModels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ include("optimizers.jl")
3333
include("fit.jl")
3434
include("prints.jl")
3535
include("forecast.jl")
36-
include("backtest.jl")
36+
include("cross_validation.jl")
3737

3838
include("models/common.jl")
3939
include("models/locallevel.jl")
@@ -52,7 +52,7 @@ include("models/dar.jl")
5252

5353
include("visualization/forecast.jl")
5454
include("visualization/components.jl")
55-
include("visualization/backtest.jl")
55+
include("visualization/cross_validation.jl")
5656
include("visualization/diagnostics.jl")
5757

5858
# Exported types and structs
@@ -82,7 +82,7 @@ export UnobservedComponents
8282

8383
# Exported functions
8484
export auto_ets
85-
export backtest
85+
export cross_validation
8686
export constrain_box!
8787
export constrain_identity!
8888
export constrain_variance!

src/backtest.jl renamed to src/cross_validation.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
struct Backtest{Fl <: AbstractFloat}
1+
struct CrossValidation{Fl <: AbstractFloat}
22
abs_errors::Matrix{Fl}
33
mae::Vector{Fl}
44
crps_scores::Matrix{Fl}
55
mean_crps::Vector{Fl}
6-
function Backtest{Fl}(n::Int, steps_ahead::Int) where Fl
6+
function CrossValidation{Fl}(n::Int, steps_ahead::Int) where Fl
77
abs_errors = Matrix{Fl}(undef, steps_ahead, n)
88
crps_scores = Matrix{Fl}(undef, steps_ahead, n)
99
mae = Vector{Fl}(undef, steps_ahead)
@@ -34,27 +34,27 @@ function evaluate_crps(y::Vector{Fl}, scenarios::Matrix{Fl}) where {Fl}
3434
end
3535

3636
"""
37-
backtest(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
37+
cross_validation(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
3838
n_scenarios::Int = 10_000,
3939
filter::KalmanFilter=default_filter(model),
4040
optimizer::Optimizer=default_optimizer(model)) where Fl
4141
4242
Makes rolling window estimating and forecasting to benchmark the forecasting skill of the model
4343
in for different time periods and different lead times. The function returns a struct with the MAE
44-
and mean CRPS per lead time. See more on [Backtest the forecasts of a model](@ref)
44+
and mean CRPS per lead time. See more on [CrossValidation the forecasts of a model](@ref)
4545
4646
# References
4747
* DTU course "31761 - Renewables in electricity markets" available on youtube https://www.youtube.com/watch?v=Ffo8XilZAZw&t=556s
4848
"""
49-
function backtest(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
49+
function cross_validation(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
5050
n_scenarios::Int = 10_000,
5151
filter::KalmanFilter=default_filter(model),
5252
optimizer::Optimizer=default_optimizer(model))
5353
Fl = typeof_model_elements(model)
5454
num_mle = length(model.system.y) - start_idx - steps_ahead
55-
b = Backtest{Fl}(num_mle, steps_ahead)
55+
cv = CrossValidation{Fl}(num_mle, steps_ahead)
5656
for i in 1:num_mle
57-
println("Backtest: step $i of $num_mle")
57+
println("CrossValidation: step $i of $num_mle")
5858
y_to_fit = model.system.y[1:start_idx - 1 + i]
5959
y_to_verify = model.system.y[start_idx + i:start_idx - 1 + i + steps_ahead]
6060
model_to_fit = reinstantiate(model, y_to_fit)
@@ -64,12 +64,12 @@ function backtest(model::StateSpaceModel, steps_ahead::Int, start_idx::Int;
6464
expected_value_vector = forecast_expected_value(forec)[:]
6565
abs_errors = evaluate_abs_error(y_to_verify, expected_value_vector)
6666
crps_scores = evaluate_crps(y_to_verify, scenarios[:, 1, :])
67-
b.abs_errors[:, i] = abs_errors
68-
b.crps_scores[:, i] = crps_scores
67+
cv.abs_errors[:, i] = abs_errors
68+
cv.crps_scores[:, i] = crps_scores
6969
end
7070
for i in 1:steps_ahead
71-
b.mae[i] = mean(b.abs_errors[i, :])
72-
b.mean_crps[i] = mean(b.crps_scores[i, :])
71+
cv.mae[i] = mean(cv.abs_errors[i, :])
72+
cv.mean_crps[i] = mean(cv.crps_scores[i, :])
7373
end
74-
return b
74+
return cv
7575
end

src/models/naive_models.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,13 @@ function reinstantiate(model::ExperimentalSeasonalNaive, y::Vector{<:Real})
226226
return ExperimentalSeasonalNaive(y, model.seasonal; S = model.S)
227227
end
228228

229-
function backtest(model::NaiveModel, steps_ahead::Int, start_idx::Int;
229+
function cross_validation(model::NaiveModel, steps_ahead::Int, start_idx::Int;
230230
n_scenarios::Int = 10_000)
231231
Fl = typeof_model_elements(model)
232232
num_fits = length(model.y) - start_idx - steps_ahead
233-
b = Backtest{Fl}(num_fits, steps_ahead)
233+
b = CrossValidation{Fl}(num_fits, steps_ahead)
234234
for i in 1:num_fits
235-
println("Backtest: step $i of $num_fits")
235+
println("CrossValidation: step $i of $num_fits")
236236
y_to_fit = model.y[1:start_idx - 1 + i]
237237
y_to_verify = model.y[start_idx + i:start_idx - 1 + i + steps_ahead]
238238
model_to_fit = reinstantiate(model, y_to_fit)
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
@recipe function f(b::Backtest, name::String)
1+
@recipe function f(cv::CrossValidation, name::String)
22
xguide := "lead times"
33
@series begin
44
seriestype := :path
55
label := "MAE " * name
66
marker := :circle
7-
b.mae
7+
cv.mae
88
end
99
@series begin
1010
seriestype := :path
1111
label := "Mean CRPS " * name
1212
marker := :circle
13-
b.mean_crps
13+
cv.mean_crps
1414
end
1515
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ include("models/dar.jl")
3131
# Visualization
3232
include("visualization/forecast.jl")
3333
include("visualization/components.jl")
34-
include("visualization/backtest.jl")
34+
include("visualization/cross_validation.jl")
3535
include("visualization/diagnostics.jl")

test/visualization/backtest.jl renamed to test/visualization/cross_validation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
log_air_passengers = log.(air_passengers.passengers)
44
model = BasicStructural(log_air_passengers, 12)
55
# forecasting
6-
b = backtest(model, 24, 110)
6+
b = cross_validation(model, 24, 110)
77
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), b, "str")
88
@test length(rec) == 2
99
end

0 commit comments

Comments
 (0)