Skip to content

Commit 3996a06

Browse files
author
Goose
committed
correcting initial point type hinting
1 parent 31bf864 commit 3996a06

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc/initial_point.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from pymc.logprob.transforms import Transform
2828
from pymc.pytensorf import (
29+
SeedSequenceSeed,
2930
compile,
3031
find_rng_nodes,
3132
replace_rng_nodes,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
6768
overrides: StartDict | Sequence[StartDict | None] | None,
6869
jitter_rvs: set[TensorVariable] | None = None,
6970
chains: int,
70-
) -> list[Callable[[int], PointType]]:
71+
) -> list[Callable[[SeedSequenceSeed], PointType]]:
7172
"""Create an initial point function for each chain, as defined by initvals.
7273
7374
If a single initval dictionary is passed, the function is replicated for each
@@ -84,7 +85,7 @@ def make_initial_point_fns_per_chain(
8485
8586
Returns
8687
-------
87-
ipfns : list[Callable[[int], dict[str, np.ndarray]]]
88+
ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
8889
list of functions that return initial points for each chain.
8990
9091
Raises
@@ -129,7 +130,7 @@ def make_initial_point_fn(
129130
jitter_rvs: set[TensorVariable] | None = None,
130131
default_strategy: str = "support_point",
131132
return_transformed: bool = True,
132-
) -> Callable[[int], PointType]:
133+
) -> Callable[[SeedSequenceSeed], PointType]:
133134
"""Create seeded function that computes initial values for all free model variables.
134135
135136
Parameters
@@ -146,7 +147,7 @@ def make_initial_point_fn(
146147
147148
Returns
148149
-------
149-
initial_point_fn : Callable[[int], dict[str, np.ndarray]]
150+
initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
150151
"""
151152
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
152153
initval_strats = {

0 commit comments

Comments
 (0)