Skip to content

Commit 945e152

Browse files
author
andre_ramos
committed
change simulation to univariate distributions
1 parent 1e95af3 commit 945e152

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StateSpaceLearning"
22
uuid = "971c4b7c-2c4e-4bac-8525-e842df3cde7b"
33
authors = ["andreramosfc <[email protected]>"]
4-
version = "1.2.0"
4+
version = "1.3.0"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/fit_forecast.jl

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,27 +196,27 @@ function simulate(
196196
@assert seasonal_innovation_simulation >= 0 "seasonal_innovation_simulation must be a non-negative integer"
197197
@assert isfitted(model) "Model must be fitted before simulation"
198198

199-
prediction = StateSpaceLearning.forecast(
199+
prediction = forecast(
200200
model, steps_ahead; Exogenous_Forecast=Exogenous_Forecast
201201
)
202202

203-
is_univariate = typeof(model.output) == StateSpaceLearning.Output
203+
is_univariate = typeof(model.output) == Output
204204

205205
simulation_X = zeros(steps_ahead, 0)
206206
valid_indexes =
207207
is_univariate ? model.output.valid_indexes : model.output[1].valid_indexes
208208
components_matrix = zeros(length(valid_indexes), 0)
209209
N_components = 1
210210

211-
model_innovations = StateSpaceLearning.get_model_innovations(model)
211+
model_innovations = get_model_innovations(model)
212212
for innovation in model_innovations
213213
simulation_X = hcat(
214214
simulation_X,
215-
StateSpaceLearning.get_innovation_simulation_X(model, innovation, steps_ahead)[
215+
get_innovation_simulation_X(model, innovation, steps_ahead)[
216216
(end - steps_ahead):(end - 1), (end - steps_ahead + 1):end
217217
],
218218
)
219-
comp = StateSpaceLearning.fill_innovation_coefs(model, innovation, valid_indexes)
219+
comp = fill_innovation_coefs(model, innovation, valid_indexes)
220220
components_matrix = hcat(components_matrix, comp)
221221
N_components += 1
222222
end
@@ -242,7 +242,11 @@ function simulate(
242242
end
243243

244244
if seasonal_innovation_simulation == 0
245-
= cov(components_matrix)
245+
= if is_univariate
246+
Diagonal([var(components_matrix[:, i]) for i in 1:N_components])
247+
else
248+
Diagonal([var(components_matrix[:, i]) for i in 1:N_mv_components])
249+
end
246250
for i in 1:steps_ahead
247251
MV_dist_vec[i] = if is_univariate
248252
MvNormal(zeros(N_components), ∑)
@@ -270,12 +274,27 @@ function simulate(
270274
end
271275
else
272276
start_seasonal_term = (size(components_matrix, 1) % seasonal_innovation_simulation)
273-
for i in 1:steps_ahead
274-
= cov(
275-
components_matrix[
276-
(i + start_seasonal_term):seasonal_innovation_simulation:end, :,
277-
],
278-
)
277+
for i in 1:seasonal_innovation_simulation
278+
= if is_univariate
279+
Diagonal([
280+
var(
281+
components_matrix[
282+
(i + start_seasonal_term):seasonal_innovation_simulation:end,
283+
j,
284+
],
285+
) for j in 1:N_components
286+
])
287+
else
288+
Diagonal([
289+
var(
290+
components_matrix[
291+
(i + start_seasonal_term):seasonal_innovation_simulation:end,
292+
j,
293+
],
294+
) for j in 1:N_mv_components
295+
])
296+
end
297+
279298
MV_dist_vec[i] = if is_univariate
280299
MvNormal(zeros(N_components), ∑)
281300
else
@@ -313,6 +332,20 @@ function simulate(
313332
end
314333
end
315334
end
335+
for i in (seasonal_innovation_simulation + 1):steps_ahead
336+
MV_dist_vec[i] = MV_dist_vec[i - seasonal_innovation_simulation]
337+
if model.outlier
338+
if is_univariate
339+
o_noises[i, :] = o_noises[i - seasonal_innovation_simulation, :]
340+
else
341+
for j in eachindex(model.output)
342+
o_noises[j][i, :] = o_noises[j][
343+
i - seasonal_innovation_simulation, :,
344+
]
345+
end
346+
end
347+
end
348+
end
316349
end
317350

318351
simulation = if is_univariate

0 commit comments

Comments
 (0)