@@ -1072,6 +1072,168 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
10721072 )
10731073 return fig
10741074
1075+ def get_ts_contribution_posterior (
1076+ self , var_contribution : str , original_scale : bool = False
1077+ ) -> DataArray :
1078+ """Get the posterior distribution of the time series contributions of a given variable.
1079+
1080+ Parameters
1081+ ----------
1082+ var_contribution : str
1083+ The variable for which to get the contributions. It must be a valid variable
1084+ in the `fit_result` attribute.
1085+ original_scale : bool, optional
1086+ Whether to plot in the original scale.
1087+
1088+ Returns
1089+ -------
1090+ DataArray
1091+ The posterior distribution of the time series contributions.
1092+ """
1093+ contributions = self ._format_model_contributions (
1094+ var_contribution = var_contribution
1095+ )
1096+
1097+ if original_scale :
1098+ return apply_sklearn_transformer_across_dim (
1099+ data = contributions ,
1100+ func = self .get_target_transformer ().inverse_transform ,
1101+ dim_name = "date" ,
1102+ )
1103+
1104+ return contributions
1105+
1106+ def plot_components_contributions (
1107+ self , original_scale : bool = False , ** plt_kwargs : Any
1108+ ) -> plt .Figure :
1109+ """Plot the target variable and the posterior predictive model components in
1110+ the scaled space.
1111+
1112+ Parameters
1113+ ----------
1114+ original_scale : bool, optional
1115+ Whether to plot in the original scale.
1116+
1117+ **plt_kwargs
1118+ Additional keyword arguments to pass to `plt.subplots`.
1119+
1120+ Returns
1121+ -------
1122+ plt.Figure
1123+ """
1124+ channel_contributions = self .get_ts_contribution_posterior (
1125+ var_contribution = "channel_contributions" , original_scale = original_scale
1126+ )
1127+
1128+ means = [channel_contributions .mean (["chain" , "draw" ])]
1129+ contribution_vars = [
1130+ az .hdi (channel_contributions , hdi_prob = 0.94 ).channel_contributions
1131+ ]
1132+
1133+ for arg , var_contribution in zip (
1134+ ["control_columns" , "yearly_seasonality" ],
1135+ ["control_contributions" , "fourier_contributions" ],
1136+ strict = True ,
1137+ ):
1138+ if getattr (self , arg , None ):
1139+ contributions = self .get_ts_contribution_posterior (
1140+ var_contribution = var_contribution , original_scale = original_scale
1141+ )
1142+
1143+ means .append (contributions .mean (["chain" , "draw" ]))
1144+ contribution_vars .append (
1145+ az .hdi (contributions , hdi_prob = 0.94 )[var_contribution ]
1146+ )
1147+
1148+ fig , ax = plt .subplots (** plt_kwargs )
1149+
1150+ for i , (mean , hdi , var_contribution ) in enumerate (
1151+ zip (
1152+ means ,
1153+ contribution_vars ,
1154+ [
1155+ "channel_contribution" ,
1156+ "control_contribution" ,
1157+ "fourier_contribution" ,
1158+ ],
1159+ strict = False ,
1160+ )
1161+ ):
1162+ if self .X is not None :
1163+ ax .fill_between (
1164+ x = self .X [self .date_column ],
1165+ y1 = hdi .isel (hdi = 0 ),
1166+ y2 = hdi .isel (hdi = 1 ),
1167+ color = f"C{ i } " ,
1168+ alpha = 0.25 ,
1169+ label = f"$94\\ %$ HDI ({ var_contribution } )" ,
1170+ )
1171+ ax .plot (
1172+ np .asarray (self .X [self .date_column ]),
1173+ np .asarray (mean ),
1174+ color = f"C{ i } " ,
1175+ )
1176+ if self .X is not None :
1177+ intercept = az .extract (
1178+ self .fit_result , var_names = ["intercept" ], combined = False
1179+ )
1180+
1181+ if original_scale :
1182+ intercept = apply_sklearn_transformer_across_dim (
1183+ data = intercept ,
1184+ func = self .get_target_transformer ().inverse_transform ,
1185+ dim_name = "chain" ,
1186+ )
1187+
1188+ if intercept .ndim == 2 :
1189+ # Intercept has a stationary prior
1190+ intercept_hdi = np .repeat (
1191+ a = az .hdi (intercept ).intercept .data [None , ...],
1192+ repeats = self .X [self .date_column ].shape [0 ],
1193+ axis = 0 ,
1194+ )
1195+ elif intercept .ndim == 3 :
1196+ # Intercept has a time-varying prior
1197+ intercept_hdi = az .hdi (intercept ).intercept .data
1198+
1199+ ax .plot (
1200+ np .asarray (self .X [self .date_column ]),
1201+ np .full (len (self .X [self .date_column ]), intercept .mean ().data ),
1202+ color = f"C{ i + 1 } " ,
1203+ )
1204+ ax .fill_between (
1205+ x = self .X [self .date_column ],
1206+ y1 = intercept_hdi [:, 0 ],
1207+ y2 = intercept_hdi [:, 1 ],
1208+ color = f"C{ i + 1 } " ,
1209+ alpha = 0.25 ,
1210+ label = "$94\\ %$ HDI (intercept)" ,
1211+ )
1212+
1213+ y_to_plot = (
1214+ self .get_target_transformer ().inverse_transform (
1215+ np .asarray (self .preprocessed_data ["y" ]).reshape (- 1 , 1 )
1216+ )
1217+ if original_scale
1218+ else np .asarray (self .preprocessed_data ["y" ])
1219+ )
1220+
1221+ ylabel = self .output_var if original_scale else f"{ self .output_var } scaled"
1222+
1223+ ax .plot (
1224+ np .asarray (self .X [self .date_column ]),
1225+ y_to_plot ,
1226+ label = ylabel ,
1227+ color = "black" ,
1228+ )
1229+ ax .legend (loc = "upper center" , bbox_to_anchor = (0.5 , - 0.1 ), ncol = 3 )
1230+ ax .set (
1231+ title = "Posterior Predictive Model Components" ,
1232+ xlabel = "date" ,
1233+ ylabel = ylabel ,
1234+ )
1235+ return fig
1236+
10751237 def plot_channel_contributions_grid (
10761238 self ,
10771239 start : float ,
0 commit comments