Skip to content

Commit 1e806e5

Browse files
Fix orders constraints
1 parent 1dfc0c2 commit 1e806e5

File tree

3 files changed

+81
-16
lines changed

3 files changed

+81
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StateSpaceModels"
22
uuid = "99342f36-827c-5390-97c9-d7f9ee765c78"
33
authors = ["raphaelsaavedra <[email protected]>, guilhermebodin <[email protected]>, mariohsouto"]
4-
version = "0.6.5"
4+
version = "0.6.6"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/models/sarima.jl

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,52 @@ function add_model_with_changed_constant!(candidate_models, visited_models)
877877
return candidate_models
878878
end
879879

880+
function add_first_non_seasonal_models!(
881+
candidate_models::Vector{SARIMA},
882+
y::Vector{Fl},
883+
d::Int,
884+
include_mean::Bool,
885+
max_p::Int,
886+
max_q::Int
887+
) where Fl <: AbstractFloat
888+
if max_p >= 2 && max_q >= 2
889+
push!(candidate_models, SARIMA(y; order = (2, d, 2), include_mean = include_mean, suppress_warns = true))
890+
end
891+
if max_p >= 1
892+
push!(candidate_models, SARIMA(y; order = (1, d, 0), include_mean = include_mean, suppress_warns = true))
893+
end
894+
if max_q >= 1
895+
push!(candidate_models, SARIMA(y; order = (0, d, 1), include_mean = include_mean, suppress_warns = true))
896+
end
897+
push!(candidate_models, SARIMA(y; order = (0, d, 0), include_mean = include_mean, suppress_warns = true))
898+
return candidate_models
899+
end
900+
901+
function add_first_seasonal_models!(
902+
candidate_models::Vector{SARIMA},
903+
y::Vector{Fl},
904+
d::Int,
905+
D::Int,
906+
include_mean::Bool,
907+
max_p::Int,
908+
max_q::Int,
909+
max_P::Int,
910+
max_Q::Int,
911+
seasonal::Int
912+
) where Fl <: AbstractFloat
913+
if max_p >= 2 && max_q >= 2 && max_P >= 1 && max_Q >= 1
914+
push!(candidate_models, SARIMA(y; order = (2, d, 2), seasonal_order = (1, D, 1, seasonal) , include_mean = include_mean, suppress_warns = true))
915+
end
916+
if max_p >= 1 && max_P >= 1
917+
push!(candidate_models, SARIMA(y; order = (1, d, 0), seasonal_order = (1, D, 0, seasonal) , include_mean = include_mean, suppress_warns = true))
918+
end
919+
if max_q >= 1 && max_Q >= 1
920+
push!(candidate_models, SARIMA(y; order = (0, d, 1), seasonal_order = (0, D, 1, seasonal) , include_mean = include_mean, suppress_warns = true))
921+
end
922+
push!(candidate_models, SARIMA(y; order = (0, d, 0), seasonal_order = (0, D, 0, seasonal) , include_mean = include_mean, suppress_warns = true))
923+
return candidate_models
924+
end
925+
880926
"""
881927
auto_arima(y::Vector{Fl};
882928
seasonal::Int = 0,
@@ -928,13 +974,13 @@ function auto_arima(y::Vector{Fl};
928974
@assert D <= max_D
929975
@assert d >= -1
930976
@assert d <= max_d
931-
@assert max_p > 0
932-
@assert max_q > 0
933-
@assert max_d > 0
934-
@assert max_P > 0
935-
@assert max_D > 0
936-
@assert max_Q > 0
937-
@assert max_order > 0
977+
@assert max_p >= 0
978+
@assert max_q >= 0
979+
@assert max_d >= 0
980+
@assert max_P >= 0
981+
@assert max_D >= 0
982+
@assert max_Q >= 0
983+
@assert max_order >= 0
938984
@assert information_criteria in ["aic", "aicc", "bic"]
939985
@assert integration_test in ["kpss"]
940986
@assert seasonal_integration_test in ["seas", "ch"]
@@ -958,15 +1004,27 @@ function auto_arima(y::Vector{Fl};
9581004

9591005
# fit the first four models
9601006
if seasonal == 0
961-
push!(candidate_models, SARIMA(y; order = (2, d, 2), include_mean = include_mean, suppress_warns = true))
962-
push!(candidate_models, SARIMA(y; order = (0, d, 0), include_mean = include_mean, suppress_warns = true))
963-
push!(candidate_models, SARIMA(y; order = (1, d, 0), include_mean = include_mean, suppress_warns = true))
964-
push!(candidate_models, SARIMA(y; order = (0, d, 1), include_mean = include_mean, suppress_warns = true))
1007+
add_first_non_seasonal_models!(
1008+
candidate_models,
1009+
y,
1010+
d,
1011+
include_mean,
1012+
max_p,
1013+
max_q
1014+
)
9651015
else
966-
push!(candidate_models, SARIMA(y; order = (2, d, 2), seasonal_order = (1, D, 1, seasonal) , include_mean = include_mean, suppress_warns = true))
967-
push!(candidate_models, SARIMA(y; order = (0, d, 0), seasonal_order = (0, D, 0, seasonal) , include_mean = include_mean, suppress_warns = true))
968-
push!(candidate_models, SARIMA(y; order = (1, d, 0), seasonal_order = (1, D, 0, seasonal) , include_mean = include_mean, suppress_warns = true))
969-
push!(candidate_models, SARIMA(y; order = (0, d, 1), seasonal_order = (0, D, 1, seasonal) , include_mean = include_mean, suppress_warns = true))
1016+
add_first_seasonal_models!(
1017+
candidate_models,
1018+
y,
1019+
d,
1020+
D,
1021+
include_mean,
1022+
max_p,
1023+
max_q,
1024+
max_P,
1025+
max_Q,
1026+
seasonal
1027+
)
9701028
end
9711029

9721030
fit_candidate_models!(candidate_models, show_trace)

test/models/sarima.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@
101101
@test model.order.q == 1
102102
@test model.include_mean == false
103103

104+
model = auto_arima(dinternet; max_q = 0)
105+
@test model.order.q == 0
106+
107+
model = auto_arima(dinternet; max_q = 0, max_p = 0)
108+
@test model.order.p == 0
109+
@test model.order.q == 0
110+
104111
nile = CSV.File(StateSpaceModels.NILE) |> DataFrame
105112
model = auto_arima(nile.flow; d = 1, show_trace = true)
106113
@test model.order.p == 1

0 commit comments

Comments
 (0)