@@ -100,7 +100,6 @@ def test_extract_components_from_idata(rng):
100
100
101
101
mod .build_statespace_graph (y )
102
102
103
- x0 , P0 , c , d , T , Z , R , H , Q = mod .unpack_statespace ()
104
103
prior = pm .sample_prior_predictive (draws = 10 )
105
104
106
105
filter_prior = mod .sample_conditional_prior (prior )
@@ -110,3 +109,43 @@ def test_extract_components_from_idata(rng):
110
109
missing = set (comp_states ) - set (expected_states )
111
110
112
111
assert len (missing ) == 0 , missing
112
+
113
+
114
+ def test_extract_multiple_observed (rng ):
115
+ time_idx = pd .date_range (start = "2000-01-01" , freq = "D" , periods = 100 )
116
+ data = pd .DataFrame (rng .normal (size = (100 , 2 )), columns = ["a" , "b" ], index = time_idx )
117
+
118
+ y = pd .DataFrame (
119
+ rng .normal (size = (100 , 3 )), columns = ["data_1" , "data_2" , "data_3" ], index = time_idx
120
+ )
121
+
122
+ ll = st .LevelTrendComponent (name = "trend" , observed_state_names = ["data_1" , "data_2" ])
123
+ season = st .FrequencySeasonality (
124
+ name = "seasonal" , observed_state_names = ["data_1" ], season_length = 12 , n = 2 , innovations = False
125
+ )
126
+ reg = st .RegressionComponent (
127
+ state_names = ["a" , "b" ], name = "exog" , observed_state_names = ["data_2" , "data_3" ]
128
+ )
129
+ me = st .MeasurementError ("obs" , observed_state_names = ["data_1" , "data_2" , "data_3" ])
130
+ mod = (ll + season + reg + me ).build (verbose = True )
131
+
132
+ with pm .Model (coords = mod .coords ) as m :
133
+ data_exog = pm .Data ("data_exog" , data .values )
134
+
135
+ x0 = pm .Normal ("x0" , dims = ["state" ])
136
+ P0 = pm .Deterministic ("P0" , pt .eye (mod .k_states ), dims = ["state" , "state_aux" ])
137
+ beta_exog = pm .Normal ("beta_exog" , dims = ["endog_exog" , "state_exog" ])
138
+ initial_trend = pm .Normal ("initial_trend" , dims = ["endog_trend" , "state_trend" ])
139
+ sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["endog_trend" , "trend_shock" ])
140
+ seasonal_coefs = pm .Normal ("seasonal" , dims = ["seasonal_state" ])
141
+ sigma_obs = pm .Exponential ("sigma_obs" , 1 , dims = ["endog_obs" ])
142
+
143
+ mod .build_statespace_graph (y )
144
+
145
+ prior = pm .sample_prior_predictive (draws = 10 )
146
+
147
+ filter_prior = mod .sample_conditional_prior (prior )
148
+ comp_prior = mod .extract_components_from_idata (filter_prior )
149
+ comp_states = comp_prior .filtered_prior .coords ["state" ].values
150
+ # expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
151
+ # missing = set(comp_states) - set(expected_states)
0 commit comments