@@ -100,7 +100,6 @@ def test_extract_components_from_idata(rng):
100100
101101 mod .build_statespace_graph (y )
102102
103- x0 , P0 , c , d , T , Z , R , H , Q = mod .unpack_statespace ()
104103 prior = pm .sample_prior_predictive (draws = 10 )
105104
106105 filter_prior = mod .sample_conditional_prior (prior )
@@ -110,3 +109,43 @@ def test_extract_components_from_idata(rng):
110109 missing = set (comp_states ) - set (expected_states )
111110
112111 assert len (missing ) == 0 , missing
112+
113+
114+ def test_extract_multiple_observed (rng ):
115+ time_idx = pd .date_range (start = "2000-01-01" , freq = "D" , periods = 100 )
116+ data = pd .DataFrame (rng .normal (size = (100 , 2 )), columns = ["a" , "b" ], index = time_idx )
117+
118+ y = pd .DataFrame (
119+ rng .normal (size = (100 , 3 )), columns = ["data_1" , "data_2" , "data_3" ], index = time_idx
120+ )
121+
122+ ll = st .LevelTrendComponent (name = "trend" , observed_state_names = ["data_1" , "data_2" ])
123+ season = st .FrequencySeasonality (
124+ name = "seasonal" , observed_state_names = ["data_1" ], season_length = 12 , n = 2 , innovations = False
125+ )
126+ reg = st .RegressionComponent (
127+ state_names = ["a" , "b" ], name = "exog" , observed_state_names = ["data_2" , "data_3" ]
128+ )
129+ me = st .MeasurementError ("obs" , observed_state_names = ["data_1" , "data_2" , "data_3" ])
130+ mod = (ll + season + reg + me ).build (verbose = True )
131+
132+ with pm .Model (coords = mod .coords ) as m :
133+ data_exog = pm .Data ("data_exog" , data .values )
134+
135+ x0 = pm .Normal ("x0" , dims = ["state" ])
136+ P0 = pm .Deterministic ("P0" , pt .eye (mod .k_states ), dims = ["state" , "state_aux" ])
137+ beta_exog = pm .Normal ("beta_exog" , dims = ["endog_exog" , "state_exog" ])
138+ initial_trend = pm .Normal ("initial_trend" , dims = ["endog_trend" , "state_trend" ])
139+ sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["endog_trend" , "trend_shock" ])
140+ seasonal_coefs = pm .Normal ("seasonal" , dims = ["seasonal_state" ])
141+ sigma_obs = pm .Exponential ("sigma_obs" , 1 , dims = ["endog_obs" ])
142+
143+ mod .build_statespace_graph (y )
144+
145+ prior = pm .sample_prior_predictive (draws = 10 )
146+
147+ filter_prior = mod .sample_conditional_prior (prior )
148+ comp_prior = mod .extract_components_from_idata (filter_prior )
149+ comp_states = comp_prior .filtered_prior .coords ["state" ].values
150+ # expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
151+ # missing = set(comp_states) - set(expected_states)
0 commit comments