@@ -116,11 +116,10 @@ def pathfinder_result_to_xarray(
116
116
>>> with pm.Model() as model:
117
117
... x = pm.Normal("x", 0, 1)
118
118
... y = pm.Normal("y", x, 1, observed=2.0)
119
- ...
120
119
>>> # Assuming we have a PathfinderResult from a pathfinder run
121
120
>>> ds = pathfinder_result_to_xarray(result, model=model)
122
121
>>> print(ds.data_vars) # Shows lbfgs_niter, elbo_argmax, status info, etc.
123
- >>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
122
+ >>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
124
123
"""
125
124
data_vars = {}
126
125
coords = {}
@@ -214,9 +213,16 @@ def multipathfinder_result_to_xarray(
214
213
>>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
215
214
>>> ds = multipathfinder_result_to_xarray(result, model=model)
216
215
>>> print("All data:", ds.data_vars)
217
- >>> print("Summary:", [k for k in ds.data_vars.keys() if not k.startswith(('paths/', 'config/', 'diagnostics/'))])
218
- >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith('paths/')])
219
- >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith('config/')])
216
+ >>> print(
217
+ ... "Summary:",
218
+ ... [
219
+ ... k
220
+ ... for k in ds.data_vars.keys()
221
+ ... if not k.startswith(("paths/", "config/", "diagnostics/"))
222
+ ... ],
223
+ ... )
224
+ >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
225
+ >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
220
226
"""
221
227
n_params = result .samples .shape [- 1 ] if result .samples is not None else None
222
228
param_coords = get_param_coords (model , n_params ) if n_params is not None else None
@@ -477,13 +483,16 @@ def add_pathfinder_to_inference_data(
477
483
>>> with pm.Model() as model:
478
484
... x = pm.Normal("x", 0, 1)
479
485
... idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
480
- ...
481
486
>>> # Assuming we have pathfinder results
482
487
>>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
483
488
>>> print(list(idata.groups())) # Will show ['posterior', 'pathfinder']
484
489
>>> # Access nested data:
485
- >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('paths/')]) # Per-path data
486
- >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('config/')]) # Config data
490
+ >>> print(
491
+ ... [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
492
+ ... ) # Per-path data
493
+ >>> print(
494
+ ... [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
495
+ ... ) # Config data
487
496
"""
488
497
# Detect if this is a multi-path result
489
498
# Use isinstance() as primary check, but fall back to duck typing for compatibility
0 commit comments