@@ -870,3 +870,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
870870 regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
871871
872872 assert_allclose (regression_effect , regression_effect_expected )
873+
874+
875+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values." )
876+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
877+ def test_foreacast_valid_index (rng ):
878+ # Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424
879+
880+ index = pd .date_range (start = "2023-05-01" , end = "2025-01-29" , freq = "D" )
881+ T , k = len (index ), 2
882+ data = np .zeros ((T , k ))
883+ idx = rng .choice (T , size = 10 , replace = False )
884+ cols = rng .choice (k , size = 10 , replace = True )
885+
886+ data [idx , cols ] = 1
887+
888+ df_holidays = pd .DataFrame (data , index = index , columns = ["Holiday 1" , "Holiday 2" ])
889+
890+ data = rng .normal (size = (T , 1 ))
891+ nan_locs = rng .choice (T , size = 10 , replace = False )
892+ data [nan_locs ] = np .nan
893+ y = pd .DataFrame (data , index = index , columns = ["sales" ])
894+
895+ level_trend = st .LevelTrendComponent (order = 1 , innovations_order = [0 ])
896+ weekly_seasonality = st .TimeSeasonality (
897+ season_length = 7 ,
898+ state_names = ["Sun" , "Mon" , "Tues" , "Wed" , "Thu" , "Fri" , "Sat" ],
899+ innovations = True ,
900+ remove_first_state = False ,
901+ )
902+ quarterly_seasonality = st .FrequencySeasonality (season_length = 365 , n = 2 , innovations = True )
903+ ar1 = st .AutoregressiveComponent (order = 1 )
904+ me = st .MeasurementError ()
905+
906+ exog = st .RegressionComponent (
907+ name = "exog" , # Name of this exogenous variable component
908+ k_exog = 2 , # Only one exogenous variable now
909+ innovations = False , # Typically fixed effect (no stochastic evolution)
910+ state_names = df_holidays .columns .tolist (),
911+ )
912+
913+ combined_model = level_trend + weekly_seasonality + quarterly_seasonality + me + ar1 + exog
914+ ss_mod = combined_model .build ()
915+
916+ with pm .Model (coords = ss_mod .coords ) as struct_model :
917+ P0_diag = pm .Gamma ("P0_diag" , alpha = 2 , beta = 10 , dims = ["state" ])
918+ P0 = pm .Deterministic ("P0" , pt .diag (P0_diag ), dims = ["state" , "state_aux" ])
919+
920+ initial_trend = pm .Normal ("initial_trend" , mu = [0 ], sigma = [0.005 ], dims = ["trend_state" ])
921+ # sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=["trend_shock"]) # Applied to the level only
922+
923+ Seasonal_coefs = pm .ZeroSumNormal (
924+ "Seasonal[s=7]_coefs" , sigma = 0.5 , dims = ["Seasonal[s=7]_state" ]
925+ ) # DOW dev. from weekly mean
926+ sigma_Seasonal = pm .Gamma (
927+ "sigma_Seasonal[s=7]" , alpha = 2 , beta = 1
928+ ) # How much this dev. can dev.
929+
930+ Frequency_coefs = pm .Normal (
931+ "Frequency[s=365, n=2]" , mu = 0 , sigma = 0.5 , dims = ["Frequency[s=365, n=2]_state" ]
932+ ) # amplitudes in short-term (weekly noise culprit)
933+ sigma_Frequency = pm .Gamma (
934+ "sigma_Frequency[s=365, n=2]" , alpha = 2 , beta = 1
935+ ) # smoothness & adaptability over time
936+
937+ ar_params = pm .Laplace ("ar_params" , mu = 0 , b = 0.2 , dims = ["ar_lag" ])
938+ sigma_ar = pm .Gamma ("sigma_ar" , alpha = 2 , beta = 1 )
939+
940+ sigma_measurement_error = pm .HalfStudentT ("sigma_MeasurementError" , nu = 3 , sigma = 1 )
941+
942+ data_exog = pm .Data ("data_exog" , df_holidays .values , dims = ["time" , "exog_state" ])
943+ beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["exog_state" ])
944+
945+ ss_mod .build_statespace_graph (y , mode = "JAX" )
946+
947+ idata = pm .sample_prior_predictive ()
948+
949+ post = ss_mod .sample_conditional_prior (idata )
950+
951+ # Define start date and forecast period
952+ start_date , n_periods = pd .to_datetime ("2024-4-15" ), 8
953+
954+ # Extract exogenous data for the forecast period
955+ scenario = {
956+ "data_exog" : pd .DataFrame (
957+ df_holidays .loc [start_date :].iloc [:n_periods ], columns = df_holidays .columns
958+ )
959+ }
960+
961+ # Generate the forecast
962+ forecasts = ss_mod .forecast (idata .prior , scenario = scenario , use_scenario_index = True )
0 commit comments