Skip to content

Commit e3366b0

Browse files
Merge pull request #44 from LAMPSPUC/univariate_distributions
change simulation to univariate distributions
2 parents 1e95af3 + 3e8858c commit e3366b0

File tree

2 files changed

+46
-15
lines changed

2 files changed

+46
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StateSpaceLearning"
22
uuid = "971c4b7c-2c4e-4bac-8525-e842df3cde7b"
33
authors = ["andreramosfc <[email protected]>"]
4-
version = "1.2.0"
4+
version = "1.3.0"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/fit_forecast.jl

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -196,27 +196,25 @@ function simulate(
196196
@assert seasonal_innovation_simulation >= 0 "seasonal_innovation_simulation must be a non-negative integer"
197197
@assert isfitted(model) "Model must be fitted before simulation"
198198

199-
prediction = StateSpaceLearning.forecast(
200-
model, steps_ahead; Exogenous_Forecast=Exogenous_Forecast
201-
)
199+
prediction = forecast(model, steps_ahead; Exogenous_Forecast=Exogenous_Forecast)
202200

203-
is_univariate = typeof(model.output) == StateSpaceLearning.Output
201+
is_univariate = typeof(model.output) == Output
204202

205203
simulation_X = zeros(steps_ahead, 0)
206204
valid_indexes =
207205
is_univariate ? model.output.valid_indexes : model.output[1].valid_indexes
208206
components_matrix = zeros(length(valid_indexes), 0)
209207
N_components = 1
210208

211-
model_innovations = StateSpaceLearning.get_model_innovations(model)
209+
model_innovations = get_model_innovations(model)
212210
for innovation in model_innovations
213211
simulation_X = hcat(
214212
simulation_X,
215-
StateSpaceLearning.get_innovation_simulation_X(model, innovation, steps_ahead)[
213+
get_innovation_simulation_X(model, innovation, steps_ahead)[
216214
(end - steps_ahead):(end - 1), (end - steps_ahead + 1):end
217215
],
218216
)
219-
comp = StateSpaceLearning.fill_innovation_coefs(model, innovation, valid_indexes)
217+
comp = fill_innovation_coefs(model, innovation, valid_indexes)
220218
components_matrix = hcat(components_matrix, comp)
221219
N_components += 1
222220
end
@@ -242,7 +240,11 @@ function simulate(
242240
end
243241

244242
if seasonal_innovation_simulation == 0
245-
= cov(components_matrix)
243+
= if is_univariate
244+
Diagonal([var(components_matrix[:, i]) for i in 1:N_components])
245+
else
246+
Diagonal([var(components_matrix[:, i]) for i in 1:N_mv_components])
247+
end
246248
for i in 1:steps_ahead
247249
MV_dist_vec[i] = if is_univariate
248250
MvNormal(zeros(N_components), ∑)
@@ -270,12 +272,27 @@ function simulate(
270272
end
271273
else
272274
start_seasonal_term = (size(components_matrix, 1) % seasonal_innovation_simulation)
273-
for i in 1:steps_ahead
274-
= cov(
275-
components_matrix[
276-
(i + start_seasonal_term):seasonal_innovation_simulation:end, :,
277-
],
278-
)
275+
for i in 1:seasonal_innovation_simulation
276+
= if is_univariate
277+
Diagonal([
278+
var(
279+
components_matrix[
280+
(i + start_seasonal_term):seasonal_innovation_simulation:end,
281+
j,
282+
],
283+
) for j in 1:N_components
284+
])
285+
else
286+
Diagonal([
287+
var(
288+
components_matrix[
289+
(i + start_seasonal_term):seasonal_innovation_simulation:end,
290+
j,
291+
],
292+
) for j in 1:N_mv_components
293+
])
294+
end
295+
279296
MV_dist_vec[i] = if is_univariate
280297
MvNormal(zeros(N_components), ∑)
281298
else
@@ -313,6 +330,20 @@ function simulate(
313330
end
314331
end
315332
end
333+
for i in (seasonal_innovation_simulation + 1):steps_ahead
334+
MV_dist_vec[i] = MV_dist_vec[i - seasonal_innovation_simulation]
335+
if model.outlier
336+
if is_univariate
337+
o_noises[i, :] = o_noises[i - seasonal_innovation_simulation, :]
338+
else
339+
for j in eachindex(model.output)
340+
o_noises[j][i, :] = o_noises[j][
341+
i - seasonal_innovation_simulation, :,
342+
]
343+
end
344+
end
345+
end
346+
end
316347
end
317348

318349
simulation = if is_univariate

0 commit comments

Comments
 (0)