@@ -733,8 +733,7 @@ def test_invalid_scenarios():
733
733
# Giving a list, tuple, or Series when a matrix of data is expected should always raise
734
734
with pytest .raises (
735
735
ValueError ,
736
- match = "Scenario data for variable 'a' has the wrong number of columns. "
737
- "Expected 2, got 1" ,
736
+ match = "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1" ,
738
737
):
739
738
for data_type in [list , tuple , pd .Series ]:
740
739
ss_mod ._validate_scenario_data (data_type (np .zeros (10 )))
@@ -743,15 +742,14 @@ def test_invalid_scenarios():
743
742
# Providing irrevelant data raises
744
743
with pytest .raises (
745
744
ValueError ,
746
- match = "Scenario data provided for variable 'jk lol', which is not an exogenous " " variable" ,
745
+ match = "Scenario data provided for variable 'jk lol', which is not an exogenous variable" ,
747
746
):
748
747
ss_mod ._validate_scenario_data ({"jk lol" : np .zeros (10 )})
749
748
750
749
# Incorrect 2nd dimension of a non-dataframe
751
750
with pytest .raises (
752
751
ValueError ,
753
- match = "Scenario data for variable 'a' has the wrong number of columns. Expected "
754
- "2, got 1" ,
752
+ match = "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1" ,
755
753
):
756
754
scenario = np .zeros (10 ).tolist ()
757
755
ss_mod ._validate_scenario_data (scenario )
@@ -1017,3 +1015,13 @@ def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
1017
1015
1018
1016
assert forecasts .forecast_latent .shape [2 ] == n_periods
1019
1017
assert forecasts .forecast_observed .shape [2 ] == n_periods
1018
+
1019
+
1020
+ @pytest .mark .parametrize ("batch_size" , [(10 ,), (10 , 3 , 5 )])
1021
+ def test_insert_batched_rvs (ss_mod , batch_size ):
1022
+ with pm .Model ():
1023
+ rho = pm .Normal ("rho" , shape = batch_size )
1024
+ zeta = pm .Normal ("zeta" , shape = batch_size )
1025
+ ss_mod ._insert_random_variables ()
1026
+ matrices = ss_mod .unpack_statespace ()
1027
+ assert matrices [4 ].type .shape == (* batch_size , 2 , 2 )
0 commit comments