@@ -126,7 +126,7 @@ def test_extract_multiple_observed(rng):
126126 reg = st .RegressionComponent (
127127 state_names = ["a" , "b" ], name = "exog" , observed_state_names = ["data_2" , "data_3" ]
128128 )
129- me = st .MeasurementError ("obs" , observed_state_names = ["data_1" , "data_2" , " data_3" ])
129+ me = st .MeasurementError ("obs" , observed_state_names = ["data_1" , "data_3" ])
130130 mod = (ll + season + reg + me ).build (verbose = True )
131131
132132 with pm .Model (coords = mod .coords ) as m :
@@ -147,5 +147,18 @@ def test_extract_multiple_observed(rng):
147147 filter_prior = mod .sample_conditional_prior (prior )
148148 comp_prior = mod .extract_components_from_idata (filter_prior )
149149 comp_states = comp_prior .filtered_prior .coords ["state" ].values
150- # expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
151- # missing = set(comp_states) - set(expected_states)
150+
151+ expected_states = [
152+ "trend[level[data_1]]" ,
153+ "trend[trend[data_1]]" ,
154+ "trend[level[data_2]]" ,
155+ "trend[trend[data_2]]" ,
156+ "seasonal" ,
157+ "exog[a[data_2]]" ,
158+ "exog[b[data_2]]" ,
159+ "exog[a[data_3]]" ,
160+ "exog[b[data_3]]" ,
161+ ]
162+
163+ missing = set (comp_states ) - set (expected_states )
164+ assert len (missing ) == 0 , missing
0 commit comments