Skip to content

Commit b512717

Browse files
author
Goose
committed
correct sampled variables in sample_prior_predictive log call & add return type overloads
1 parent d1aff0b commit b512717

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

pymc/sampling/forward.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Any,
2323
TypeAlias,
2424
cast,
25+
overload,
2526
)
2627

2728
import numpy as np
@@ -360,6 +361,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None):
360361
]
361362

362363

364+
@overload
365+
def sample_prior_predictive(
366+
draws: int = 500,
367+
model: Model | None = None,
368+
var_names: Iterable[str] | None = None,
369+
random_seed: RandomState = None,
370+
return_inferencedata: bool = True,
371+
idata_kwargs: dict | None = None,
372+
compile_kwargs: dict | None = None,
373+
samples: int | None = None,
374+
) -> InferenceData: ...
375+
@overload
376+
def sample_prior_predictive(
377+
draws: int = 500,
378+
model: Model | None = None,
379+
var_names: Iterable[str] | None = None,
380+
random_seed: RandomState = None,
381+
return_inferencedata: bool = False,
382+
idata_kwargs: dict | None = None,
383+
compile_kwargs: dict | None = None,
384+
samples: int | None = None,
385+
) -> dict[str, np.ndarray]: ...
363386
def sample_prior_predictive(
364387
draws: int = 500,
365388
model: Model | None = None,
@@ -449,7 +472,7 @@ def sample_prior_predictive(
449472
)
450473

451474
# All model variables have a name, but mypy does not know this
452-
_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
475+
_log.info(f"Sampling: {sorted(vars_to_sample, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
453476
values = zip(*(sampler_fn() for i in range(draws)))
454477

455478
data = {k: np.stack(v) for k, v in zip(names, values)}

0 commit comments

Comments
 (0)