@@ -126,7 +126,7 @@ def test_extract_multiple_observed(rng):
126
126
reg = st .RegressionComponent (
127
127
state_names = ["a" , "b" ], name = "exog" , observed_state_names = ["data_2" , "data_3" ]
128
128
)
129
- me = st .MeasurementError ("obs" , observed_state_names = ["data_1" , "data_2" , " data_3" ])
129
+ me = st .MeasurementError ("obs" , observed_state_names = ["data_1" , "data_3" ])
130
130
mod = (ll + season + reg + me ).build (verbose = True )
131
131
132
132
with pm .Model (coords = mod .coords ) as m :
@@ -147,5 +147,18 @@ def test_extract_multiple_observed(rng):
147
147
filter_prior = mod .sample_conditional_prior (prior )
148
148
comp_prior = mod .extract_components_from_idata (filter_prior )
149
149
comp_states = comp_prior .filtered_prior .coords ["state" ].values
150
- # expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
151
- # missing = set(comp_states) - set(expected_states)
150
+
151
+ expected_states = [
152
+ "trend[level[data_1]]" ,
153
+ "trend[trend[data_1]]" ,
154
+ "trend[level[data_2]]" ,
155
+ "trend[trend[data_2]]" ,
156
+ "seasonal" ,
157
+ "exog[a[data_2]]" ,
158
+ "exog[b[data_2]]" ,
159
+ "exog[a[data_3]]" ,
160
+ "exog[b[data_3]]" ,
161
+ ]
162
+
163
+ missing = set (comp_states ) - set (expected_states )
164
+ assert len (missing ) == 0 , missing
0 commit comments