@@ -1252,108 +1252,207 @@ def plot_sensitivity_analysis(
1252
1252
ax : plt .Axes | None = None ,
1253
1253
marginal : bool = False ,
1254
1254
percentage : bool = False ,
1255
- ) -> plt .Axes :
1255
+ sharey : bool = True ,
1256
+ ) -> tuple [Figure , NDArray [Axes ]] | plt .Axes :
1256
1257
"""
1257
- Plot the counterfactual uplift or marginal effects curve.
1258
+ Plot counterfactual uplift or marginal effects curves.
1259
+
1260
+ Handles additional (non sweep/date/chain/draw) dimensions by creating one subplot
1261
+ per combination of those dimensions - consistent with other plot_* methods.
1258
1262
1259
1263
Parameters
1260
1264
----------
1261
- results : xr.Dataset
1262
- The dataset containing the results of the sweep.
1263
- hdi_prob : float, optional
1264
- The probability for computing the highest density interval (HDI). Default is 0.94.
1265
- ax : Optional[plt.Axes], optional
1266
- An optional matplotlib Axes on which to plot. If None, a new Axes is created.
1267
- marginal : bool, optional
1268
- If True, plot marginal effects. If False (default), plot uplift.
1269
- percentage : bool, optional
1270
- If True, plot the results on the y-axis as percentages, instead of absolute
1271
- values. Default is False.
1265
+ hdi_prob : float, default 0.94
1266
+ HDI probability mass.
1267
+ ax : plt.Axes, optional
1268
+ Only used when there are no extra dimensions (single panel case).
1269
+ marginal : bool, default False
1270
+ Plot marginal effects instead of uplift.
1271
+ percentage : bool, default False
1272
+ Express uplift as a percentage of actual (not supported for marginal).
1273
+ sharey : bool, default True
1274
+ Share y-axis across subplots (only relevant for multi-panel case).
1272
1275
1273
1276
Returns
1274
1277
-------
1275
- plt.Axes
1276
- The Axes object with the plot.
1278
+ (fig, axes) if multi-panel, else a single Axes (backwards compatible single-dim case).
1277
1279
"""
1278
- if ax is None :
1279
- _ , ax = plt .subplots (figsize = (10 , 6 ))
1280
-
1281
1280
if percentage and marginal :
1282
1281
raise ValueError ("Not implemented marginal effects in percentage scale." )
1283
1282
1284
- # Check if sensitivity analysis results exist in idata
1285
1283
if not hasattr (self .idata , "sensitivity_analysis" ):
1286
1284
raise ValueError (
1287
1285
"No sensitivity analysis results found in 'self.idata'. "
1288
- "Please run the sensitivity analysis first using 'mmm.sensitivity.run_sweep()' method."
1286
+ "Run 'mmm.sensitivity.run_sweep()' first."
1287
+ )
1288
+
1289
+ results : xr .Dataset = self .idata .sensitivity_analysis # type: ignore
1290
+
1291
+ # Required variable presence checks
1292
+ required_var = "marginal_effects" if marginal else "y"
1293
+ if required_var not in results :
1294
+ raise ValueError (
1295
+ f"Expected '{ required_var } ' in sensitivity_analysis results, found: { list (results .data_vars )} "
1296
+ )
1297
+ if "sweep" not in results .dims :
1298
+ raise ValueError (
1299
+ "Sensitivity analysis results must contain 'sweep' dimension."
1289
1300
)
1290
1301
1291
- # grab sensitivity analysis results from idata
1292
- results = self .idata .sensitivity_analysis
1293
-
1294
- x = results .sweep .values
1295
- if marginal :
1296
- y = results .marginal_effects .mean (dim = ["chain" , "draw" ]).sum (dim = "date" )
1297
- y_hdi = results .marginal_effects .sum (dim = "date" )
1298
- color = "C1"
1299
- label = "Posterior mean marginal effect"
1300
- title = "Marginal effects plot"
1301
- ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
1302
+ # Identify additional dimensions
1303
+ ignored_dims = {"chain" , "draw" , "date" , "sweep" }
1304
+ base_data = results .marginal_effects if marginal else results .y
1305
+ additional_dims = [d for d in base_data .dims if d not in ignored_dims ]
1306
+
1307
+ # Build all coordinate combinations
1308
+ if additional_dims :
1309
+ additional_coords = [results .coords [d ].values for d in additional_dims ]
1310
+ dim_combinations = list (itertools .product (* additional_coords ))
1302
1311
else :
1303
- if percentage :
1304
- actual = self .idata .posterior_predictive ["y" ]
1305
- y = results .y .mean (dim = ["chain" , "draw" ]).sum (dim = "date" ) / actual .mean (
1306
- dim = ["chain" , "draw" ]
1307
- ).sum (dim = "date" )
1308
- y_hdi = results .y .sum (dim = "date" ) / actual .sum (dim = "date" )
1309
- else :
1310
- y = results .y .mean (dim = ["chain" , "draw" ]).sum (dim = "date" )
1311
- y_hdi = results .y .sum (dim = "date" )
1312
- color = "C0"
1313
- label = "Posterior mean"
1314
- title = "Sensitivity analysis plot"
1315
- ylabel = "Total uplift (sum over dates)"
1316
-
1317
- ax .plot (x , y , label = label , color = color )
1318
-
1319
- az .plot_hdi (
1320
- x ,
1321
- y_hdi ,
1322
- hdi_prob = hdi_prob ,
1323
- color = color ,
1324
- fill_kwargs = {"alpha" : 0.5 , "label" : f"{ hdi_prob * 100 :.0f} % HDI" },
1325
- plot_kwargs = {"color" : color , "alpha" : 0.5 },
1326
- smooth = False ,
1327
- ax = ax ,
1328
- )
1312
+ dim_combinations = [()]
1313
+
1314
+ multi_panel = len (dim_combinations ) > 1
1315
+
1316
+ # If user provided ax but multiple panels needed, raise (consistent with other methods)
1317
+ if multi_panel and ax is not None :
1318
+ raise ValueError (
1319
+ "Cannot use 'ax' when there are extra dimensions. "
1320
+ "Let the function create its own subplots."
1321
+ )
1329
1322
1330
- ax .set (title = title )
1331
- if results .sweep_type == "absolute" :
1332
- ax .set_xlabel (f"Absolute value of: { results .var_names } " )
1323
+ # Prepare figure/axes
1324
+ if multi_panel :
1325
+ fig , axes = self ._init_subplots (n_subplots = len (dim_combinations ), ncols = 1 )
1326
+ if sharey :
1327
+ # Align y limits later - collect mins/maxs
1328
+ y_mins , y_maxs = [], []
1333
1329
else :
1334
- ax .set_xlabel (
1335
- f"{ results .sweep_type .capitalize ()} change of: { results .var_names } "
1330
+ if ax is None :
1331
+ fig , axes_arr = plt .subplots (figsize = (10 , 6 ))
1332
+ ax = axes_arr # type: ignore
1333
+ fig = ax .get_figure () # type: ignore
1334
+ axes = np .array ([[ax ]]) # type: ignore
1335
+
1336
+ sweep_values = results .coords ["sweep" ].values
1337
+
1338
+ # Helper: select subset (only dims present)
1339
+ def _select (data : xr .DataArray , indexers : dict ) -> xr .DataArray :
1340
+ valid = {k : v for k , v in indexers .items () if k in data .dims }
1341
+ return data .sel (** valid )
1342
+
1343
+ for row_idx , combo in enumerate (dim_combinations ):
1344
+ current_ax = axes [row_idx ][0 ] if multi_panel else ax # type: ignore
1345
+ indexers = (
1346
+ dict (zip (additional_dims , combo , strict = False ))
1347
+ if additional_dims
1348
+ else {}
1336
1349
)
1337
- ax .set_ylabel (ylabel )
1338
- plt .legend ()
1339
-
1340
- # Set y-axis limits based on the sign of y values
1341
- y_values = y .values if hasattr (y , "values" ) else np .array (y )
1342
- if np .all (y_values < 0 ):
1343
- ax .set_ylim (top = 0 )
1344
- elif np .all (y_values > 0 ):
1345
- ax .set_ylim (bottom = 0 )
1346
-
1347
- ax .yaxis .set_major_formatter (
1348
- plt .FuncFormatter (lambda x , _ : f"{ x :.1%} " if percentage else f"{ x :,.1f} " )
1349
- )
1350
1350
1351
- # Add reference lines
1352
- if results .sweep_type == "multiplicative" :
1353
- ax .axvline (x = 1 , color = "k" , linestyle = "--" , alpha = 0.5 )
1354
- if not marginal :
1355
- ax .axhline (y = 0 , color = "k" , linestyle = "--" , alpha = 0.5 )
1356
- elif results .sweep_type == "additive" :
1357
- ax .axvline (x = 0 , color = "k" , linestyle = "--" , alpha = 0.5 )
1351
+ if marginal :
1352
+ eff = _select (results .marginal_effects , indexers )
1353
+ # mean over chain/draw, sum over date (and any leftover dims not indexed)
1354
+ leftover = [d for d in eff .dims if d in ("date" ,) and d != "sweep" ]
1355
+ y_mean = eff .mean (dim = ["chain" , "draw" ]).sum (dim = leftover )
1356
+ y_hdi_data = eff .sum (dim = leftover )
1357
+ color = "C1"
1358
+ label = "Posterior mean marginal effect"
1359
+ title = "Marginal effects"
1360
+ ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
1361
+ else :
1362
+ y_da = _select (results .y , indexers )
1363
+ leftover = [d for d in y_da .dims if d in ("date" ,) and d != "sweep" ]
1364
+ if percentage :
1365
+ actual = self .idata .posterior_predictive ["y" ] # type: ignore
1366
+ actual_sel = _select (actual , indexers )
1367
+ actual_mean = actual_sel .mean (dim = ["chain" , "draw" ]).sum (
1368
+ dim = leftover
1369
+ )
1370
+ actual_sum = actual_sel .sum (dim = leftover )
1371
+ y_mean = (
1372
+ y_da .mean (dim = ["chain" , "draw" ]).sum (dim = leftover ) / actual_mean
1373
+ )
1374
+ y_hdi_data = y_da .sum (dim = leftover ) / actual_sum
1375
+ else :
1376
+ y_mean = y_da .mean (dim = ["chain" , "draw" ]).sum (dim = leftover )
1377
+ y_hdi_data = y_da .sum (dim = leftover )
1378
+ color = "C0"
1379
+ label = "Posterior mean uplift"
1380
+ title = "Sensitivity analysis"
1381
+ ylabel = "Total uplift (sum over dates)"
1382
+
1383
+ # Ensure ordering: y_mean dimension 'sweep'
1384
+ if "sweep" not in y_mean .dims :
1385
+ raise ValueError ("Expected 'sweep' dim after aggregation." )
1386
+
1387
+ current_ax .plot (sweep_values , y_mean , label = label , color = color ) # type: ignore
1388
+
1389
+ # Plot HDI
1390
+ az .plot_hdi (
1391
+ sweep_values ,
1392
+ y_hdi_data ,
1393
+ hdi_prob = hdi_prob ,
1394
+ color = color ,
1395
+ fill_kwargs = {"alpha" : 0.4 , "label" : f"{ hdi_prob * 100 :.0f} % HDI" },
1396
+ plot_kwargs = {"color" : color , "alpha" : 0.5 },
1397
+ smooth = False ,
1398
+ ax = current_ax ,
1399
+ )
1358
1400
1359
- return ax
1401
+ # Titles / labels
1402
+ if additional_dims :
1403
+ subplot_title = self ._build_subplot_title (
1404
+ additional_dims , combo , fallback_title = title
1405
+ )
1406
+ else :
1407
+ subplot_title = title
1408
+ current_ax .set_title (subplot_title ) # type: ignore
1409
+ if results .sweep_type == "absolute" :
1410
+ current_ax .set_xlabel (f"Absolute value of: { results .var_names } " ) # type: ignore
1411
+ else :
1412
+ current_ax .set_xlabel ( # type: ignore
1413
+ f"{ results .sweep_type .capitalize ()} change of: { results .var_names } "
1414
+ )
1415
+ current_ax .set_ylabel (ylabel ) # type: ignore
1416
+
1417
+ # Baseline reference lines
1418
+ if results .sweep_type == "multiplicative" :
1419
+ current_ax .axvline (x = 1 , color = "k" , linestyle = "--" , alpha = 0.5 ) # type: ignore
1420
+ if not marginal :
1421
+ current_ax .axhline (y = 0 , color = "k" , linestyle = "--" , alpha = 0.5 ) # type: ignore
1422
+ elif results .sweep_type == "additive" :
1423
+ current_ax .axvline (x = 0 , color = "k" , linestyle = "--" , alpha = 0.5 ) # type: ignore
1424
+
1425
+ # Format y
1426
+ if percentage :
1427
+ current_ax .yaxis .set_major_formatter ( # type: ignore
1428
+ plt .FuncFormatter (lambda v , _ : f"{ v :.1%} " ) # type: ignore
1429
+ )
1430
+ else :
1431
+ current_ax .yaxis .set_major_formatter ( # type: ignore
1432
+ plt .FuncFormatter (lambda v , _ : f"{ v :,.1f} " ) # type: ignore
1433
+ )
1434
+
1435
+ # Adjust y-lims sign aware
1436
+ y_vals = y_mean .values
1437
+ if np .all (y_vals < 0 ):
1438
+ current_ax .set_ylim (top = 0 ) # type: ignore
1439
+ elif np .all (y_vals > 0 ):
1440
+ current_ax .set_ylim (bottom = 0 ) # type: ignore
1441
+
1442
+ if multi_panel and sharey :
1443
+ y_mins .append (current_ax .get_ylim ()[0 ]) # type: ignore
1444
+ y_maxs .append (current_ax .get_ylim ()[1 ]) # type: ignore
1445
+
1446
+ current_ax .legend (loc = "best" ) # type: ignore
1447
+
1448
+ # Share y limits if requested
1449
+ if multi_panel and sharey :
1450
+ global_min , global_max = min (y_mins ), max (y_maxs )
1451
+ for row_idx in range (len (dim_combinations )):
1452
+ axes [row_idx ][0 ].set_ylim (global_min , global_max )
1453
+
1454
+ if multi_panel :
1455
+ fig .tight_layout ()
1456
+ return fig , axes
1457
+ else :
1458
+ return ax # single axis for backwards compatibility
0 commit comments