Skip to content
14 changes: 12 additions & 2 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pymc.logprob.transforms import Transform
from pymc.pytensorf import (
SeedSequenceSeed,
compile,
find_rng_nodes,
replace_rng_nodes,
Expand Down Expand Up @@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
chains: int,
) -> list[Callable]:
) -> list[Callable[[SeedSequenceSeed], PointType]]:
"""Create an initial point function for each chain, as defined by initvals.

If a single initval dictionary is passed, the function is replicated for each
Expand All @@ -82,6 +83,11 @@ def make_initial_point_fns_per_chain(
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)

Returns
-------
ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
list of functions that return initial points for each chain.

Raises
------
ValueError
Expand Down Expand Up @@ -124,7 +130,7 @@ def make_initial_point_fn(
jitter_rvs: set[TensorVariable] | None = None,
default_strategy: str = "support_point",
return_transformed: bool = True,
) -> Callable:
) -> Callable[[SeedSequenceSeed], PointType]:
"""Create seeded function that computes initial values for all free model variables.

Parameters
Expand All @@ -138,6 +144,10 @@ def make_initial_point_fn(
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
return_transformed : bool
If `True` the returned variables will correspond to transformed initial values.

Returns
-------
initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
"""
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
Expand Down
4 changes: 2 additions & 2 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import types
import warnings

from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from typing import (
Literal,
cast,
Expand Down Expand Up @@ -585,7 +585,7 @@ def compile_logp(
jacobian: bool = True,
sum: bool = True,
**compile_kwargs,
) -> PointFunc:
) -> Callable[[PointType], np.ndarray]:
"""Compiled log probability density function.

Parameters
Expand Down
Loading