@@ -710,8 +710,7 @@ def test_invalid_scenarios():
710
710
# Giving a list, tuple, or Series when a matrix of data is expected should always raise
711
711
with pytest .raises (
712
712
ValueError ,
713
- match = "Scenario data for variable 'a' has the wrong number of columns. "
714
- "Expected 2, got 1" ,
713
+ match = "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1" ,
715
714
):
716
715
for data_type in [list , tuple , pd .Series ]:
717
716
ss_mod ._validate_scenario_data (data_type (np .zeros (10 )))
@@ -720,15 +719,14 @@ def test_invalid_scenarios():
720
719
# Providing irrevelant data raises
721
720
with pytest .raises (
722
721
ValueError ,
723
- match = "Scenario data provided for variable 'jk lol', which is not an exogenous " " variable" ,
722
+ match = "Scenario data provided for variable 'jk lol', which is not an exogenous variable" ,
724
723
):
725
724
ss_mod ._validate_scenario_data ({"jk lol" : np .zeros (10 )})
726
725
727
726
# Incorrect 2nd dimension of a non-dataframe
728
727
with pytest .raises (
729
728
ValueError ,
730
- match = "Scenario data for variable 'a' has the wrong number of columns. Expected "
731
- "2, got 1" ,
729
+ match = "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1" ,
732
730
):
733
731
scenario = np .zeros (10 ).tolist ()
734
732
ss_mod ._validate_scenario_data (scenario )
@@ -870,3 +868,13 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
870
868
regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
871
869
872
870
assert_allclose (regression_effect , regression_effect_expected )
871
+
872
+
873
+ @pytest .mark .parametrize ("batch_size" , [(10 ,), (10 , 3 , 5 )])
874
+ def test_insert_batched_rvs (ss_mod , batch_size ):
875
+ with pm .Model ():
876
+ rho = pm .Normal ("rho" , shape = batch_size )
877
+ zeta = pm .Normal ("zeta" , shape = batch_size )
878
+ ss_mod ._insert_random_variables ()
879
+ matrices = ss_mod .unpack_statespace ()
880
+ assert matrices [4 ].type .shape == (* batch_size , 2 , 2 )
0 commit comments