File tree Expand file tree Collapse file tree 2 files changed +4
-27
lines changed Expand file tree Collapse file tree 2 files changed +4
-27
lines changed Original file line number Diff line number Diff line change @@ -52,29 +52,11 @@ def convert_str_to_rv_dict(
52
52
return initvals
53
53
54
54
55
- def filter_rvs_to_jitter (step ) -> Set [TensorVariable ]:
56
- """Find the set of RVs for which the responsible step methods ask for
57
- the addition of jitter to the initial point.
58
-
59
- Parameters
60
- ----------
61
- step : BlockedStep or CompoundStep
62
- One or many step methods that were assigned model variables.
63
-
64
- Returns
65
- -------
66
- rvs_to_jitter : set
67
- The random variables for which jitter should be added.
68
- """
69
- # TODO: implement this
70
- return set ()
71
-
72
-
73
55
def make_initial_point_fns_per_chain (
74
56
* ,
75
57
model ,
76
58
overrides : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]],
77
- jitter_rvs : Set [TensorVariable ],
59
+ jitter_rvs : Optional [ Set [TensorVariable ]] = None ,
78
60
chains : int ,
79
61
) -> List [Callable ]:
80
62
"""Create an initial point function for each chain, as defined by initvals
@@ -87,7 +69,7 @@ def make_initial_point_fns_per_chain(
87
69
overrides : optional, list or dict
88
70
Initial value strategy overrides that should take precedence over the defaults from the model.
89
71
A sequence of None or dicts will be treated as chain-wise strategies and must have the same length as `seeds`.
90
- jitter_rvs : set
72
+ jitter_rvs : set, optional
91
73
Random variable tensors for which U(-1, 1) jitter shall be applied.
92
74
(To the transformed space if applicable.)
93
75
Original file line number Diff line number Diff line change 37
37
from pymc .backends .base import BaseTrace , MultiTrace , _choose_chains
38
38
from pymc .blocking import DictToArrayBijection
39
39
from pymc .exceptions import SamplingError
40
- from pymc .initial_point import (
41
- PointType ,
42
- StartDict ,
43
- filter_rvs_to_jitter ,
44
- make_initial_point_fns_per_chain ,
45
- )
40
+ from pymc .initial_point import PointType , StartDict , make_initial_point_fns_per_chain
46
41
from pymc .model import Model , modelcontext
47
42
from pymc .sampling .parallel import Draw , _cpu_count
48
43
from pymc .sampling .population import _sample_population
@@ -476,7 +471,7 @@ def sample(
476
471
ipfns = make_initial_point_fns_per_chain (
477
472
model = model ,
478
473
overrides = initvals ,
479
- jitter_rvs = filter_rvs_to_jitter ( step ),
474
+ jitter_rvs = set ( ),
480
475
chains = chains ,
481
476
)
482
477
initial_points = [ipfn (seed ) for ipfn , seed in zip (ipfns , random_seed_list )]
You can’t perform that action at this time.
0 commit comments