Skip to content

Commit db74f2d

Browse files
fix save error
1 parent 087b14d commit db74f2d

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

src/cross_validation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function cross_validation(model::StateSpaceModel, steps_ahead::Int, start_idx::I
5858
y_to_fit = model.system.y[1:start_idx - 1 + i]
5959
y_to_verify = model.system.y[start_idx + i:start_idx - 1 + i + steps_ahead]
6060
model_to_fit = reinstantiate(model, y_to_fit)
61-
fit!(model_to_fit; filter=filter, optimizer=optimizer, save_results=false)
61+
fit!(model_to_fit; filter=filter, optimizer=optimizer, save_hyperparameter_distribution_results=false)
6262
forec = forecast(model_to_fit, steps_ahead; filter=filter)
6363
scenarios = simulate_scenarios(model_to_fit, steps_ahead, n_scenarios; filter=filter)
6464
expected_value_vector = forecast_expected_value(forec)[:]

src/fit.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
model::StateSpaceModel;
44
filter::KalmanFilter=default_filter(model),
55
optimizer::Optimizer=Optimizer(Optim.LBFGS()),
6-
save_results::Bool=true
6+
save_hyperparameter_distribution_results::Bool=true
77
)
88
99
Estimate the state-space model parameters via maximum likelihood. The resulting optimal
@@ -30,7 +30,7 @@ function fit!(
3030
model::StateSpaceModel;
3131
filter::KalmanFilter=default_filter(model),
3232
optimizer::Optimizer=default_optimizer(model),
33-
save_results::Bool=true
33+
save_hyperparameter_distribution_results::Bool=true
3434
)
3535
isfitted(model) && return model
3636
@assert has_fit_methods(typeof(model))
@@ -48,7 +48,7 @@ function fit!(
4848
opt_hyperparameters = opt.minimizer
4949
update_model_hyperparameters!(model, opt_hyperparameters)
5050

51-
if save_results
51+
if save_hyperparameter_distribution_results
5252
numerical_hessian = Optim.hessian!(func, opt_hyperparameters)
5353
try
5454
std_err = numerical_hessian |> pinv |> diag .|> sqrt
@@ -60,11 +60,15 @@ function fit!(
6060
"If you are interested in estimates of the distribution of ther hyperparameters we advise you to" *
6161
"change the optimization algorithm by using the kwarg fit(...; optimizer = "*
6262
"Optimizer(StateSpaceModels.Optim.THE_METHOD_OF_YOUR_CHOICE()))" *
63-
"The list of possible algorithms can be found on this link https://julianlsolvers.github.io/Optim.jl/stable/#"
63+
"The list of possible algorithms can be found on this link https://julianlsolvers.github.io/Optim.jl/stable/#" *
64+
"Otherwise you can simply skip this proccess by using fit(...; save_hyperparameter_distribution_results=false) "
6465
)
6566
std_err = fill(NaN, number_hyperparameters(model))
6667
fill_results!(model, opt_loglikelihood, std_err)
6768
end
69+
else
70+
std_err = fill(NaN, number_hyperparameters(model))
71+
fill_results!(model, opt_loglikelihood, std_err)
6872
end
6973
return model
7074
end

src/models/sarima.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ function fit_candidate_models!(candidate_models::Vector{SARIMA}, show_trace::Boo
761761
non_converged_models = Int[]
762762
for (i, model) in enumerate(candidate_models)
763763
try
764-
fit!(model; save_results=false)
764+
fit!(model; save_hyperparameter_distribution_results=false)
765765
if isnan(model.results.llk)
766766
show_trace && println(model, " - diverged")
767767
push!(non_converged_models, i)

0 commit comments

Comments
 (0)