|
22 | 22 | Any, |
23 | 23 | TypeAlias, |
24 | 24 | cast, |
| 25 | + overload, |
25 | 26 | ) |
26 | 27 |
|
27 | 28 | import numpy as np |
@@ -360,6 +361,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None): |
360 | 361 | ] |
361 | 362 |
|
362 | 363 |
|
| 364 | +@overload |
| 365 | +def sample_prior_predictive( |
| 366 | + draws: int = 500, |
| 367 | + model: Model | None = None, |
| 368 | + var_names: Iterable[str] | None = None, |
| 369 | + random_seed: RandomState = None, |
| 370 | + return_inferencedata: bool = True, |
| 371 | + idata_kwargs: dict | None = None, |
| 372 | + compile_kwargs: dict | None = None, |
| 373 | + samples: int | None = None, |
| 374 | +) -> InferenceData: ... |
| 375 | +@overload |
| 376 | +def sample_prior_predictive( |
| 377 | + draws: int = 500, |
| 378 | + model: Model | None = None, |
| 379 | + var_names: Iterable[str] | None = None, |
| 380 | + random_seed: RandomState = None, |
| 381 | + return_inferencedata: bool = False, |
| 382 | + idata_kwargs: dict | None = None, |
| 383 | + compile_kwargs: dict | None = None, |
| 384 | + samples: int | None = None, |
| 385 | +) -> dict[str, np.ndarray]: ... |
363 | 386 | def sample_prior_predictive( |
364 | 387 | draws: int = 500, |
365 | 388 | model: Model | None = None, |
@@ -449,7 +472,7 @@ def sample_prior_predictive( |
449 | 472 | ) |
450 | 473 |
|
451 | 474 | # All model variables have a name, but mypy does not know this |
452 | | - _log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] |
| 475 | + _log.info(f"Sampling: {sorted(vars_to_sample, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] |
453 | 476 | values = zip(*(sampler_fn() for i in range(draws))) |
454 | 477 |
|
455 | 478 | data = {k: np.stack(v) for k, v in zip(names, values)} |
|
0 commit comments