2626
2727from pymc .logprob .transforms import Transform
2828from pymc .pytensorf import (
29+ SeedSequenceSeed ,
2930 compile ,
3031 find_rng_nodes ,
3132 replace_rng_nodes ,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
6768 overrides : StartDict | Sequence [StartDict | None ] | None ,
6869 jitter_rvs : set [TensorVariable ] | None = None ,
6970 chains : int ,
70- ) -> list [Callable [[int ], PointType ]]:
71+ ) -> list [Callable [[SeedSequenceSeed ], PointType ]]:
7172 """Create an initial point function for each chain, as defined by initvals.
7273
7374 If a single initval dictionary is passed, the function is replicated for each
@@ -84,7 +85,7 @@ def make_initial_point_fns_per_chain(
8485
8586 Returns
8687 -------
87- ipfns : list[Callable[[int ], dict[str, np.ndarray]]]
88+ ipfns : list[Callable[[SeedSequenceSeed ], dict[str, np.ndarray]]]
8889 list of functions that return initial points for each chain.
8990
9091 Raises
@@ -129,7 +130,7 @@ def make_initial_point_fn(
129130 jitter_rvs : set [TensorVariable ] | None = None ,
130131 default_strategy : str = "support_point" ,
131132 return_transformed : bool = True ,
132- ) -> Callable [[int ], PointType ]:
133+ ) -> Callable [[SeedSequenceSeed ], PointType ]:
133134 """Create seeded function that computes initial values for all free model variables.
134135
135136 Parameters
@@ -146,7 +147,7 @@ def make_initial_point_fn(
146147
147148 Returns
148149 -------
149- initial_point_fn : Callable[[int ], dict[str, np.ndarray]]
150+ initial_point_fn : Callable[[SeedSequenceSeed ], dict[str, np.ndarray]]
150151 """
151152 sdict_overrides = convert_str_to_rv_dict (model , overrides or {})
152153 initval_strats = {
0 commit comments