Skip to content

Commit b6517af

Browse files
committed
Added docstrings and simplified tests
1 parent e0048bf commit b6517af

File tree

3 files changed

+23
-26
lines changed

3 files changed

+23
-26
lines changed

pymc/initial_point.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ def make_initial_point_fns_per_chain(
7171
) -> list[Callable]:
7272
"""Create an initial point function for each chain, as defined by initvals.
7373
74-
If a single initval dictionary is passed, the function is replicated for each
75-
chain, otherwise a unique function is compiled for each entry in the dictionary.
74+
If a single initval dictionary is passed, the function is replicated for each chain, otherwise a unique function is compiled for each entry in the dictionary.
7675
7776
Parameters
7877
----------
@@ -82,6 +81,8 @@ def make_initial_point_fns_per_chain(
8281
jitter_rvs : set, optional
8382
Random variable tensors for which U(-1, 1) jitter shall be applied.
8483
(To the transformed space if applicable.)
84+
jitter_scale : float, optional
85+
The scale of the jitter in the jitter_rvs set. Defaults to 1.0.
8586
8687
Raises
8788
------
@@ -134,8 +135,9 @@ def make_initial_point_fn(
134135
Parameters
135136
----------
136137
jitter_rvs : set
137-
The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
138-
added to the initial value. Only available for variables that have a transform or real-valued support.
138+
The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support.
139+
jitter_scale : float, optional
140+
The scale of the jitter in the jitter_rvs set. Defaults to 1.0.
139141
default_strategy : str
140142
Which of { "support_point", "prior" } to prefer if the initval setting for an RV is None.
141143
overrides : dict
@@ -209,8 +211,10 @@ def make_initial_point_expression(
209211
Mapping of free random variable tensors to initial value strategies.
210212
For example the `Model.initial_values` dictionary.
211213
jitter_rvs : set
212-
The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
214+
The set (or list or tuple) of random variables for which a U(-1, 1) jitter should be
213215
added to the initial value. Only available for variables that have a transform or real-valued support.
216+
jitter_scale : float, optional
217+
The scale of the jitter in the jitter_rvs set. Defaults to 1.0.
214218
default_strategy : str
215219
Which of { "support_point", "prior" } to prefer if the initval strategy setting for an RV is None.
216220
return_transformed : bool

pymc/sampling/mcmc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,8 @@ def _init_jitter(
14381438
----------
14391439
jitter: bool
14401440
Whether to apply jitter or not.
1441+
jitter_scale : float, optional
1442+
The scale of the jitter in set(model.free_RVs). Defaults to 1.0.
14411443
jitter_max_retries : int
14421444
Maximum number of repeated attempts at initializing values (per chain).
14431445

tests/test_initial_point.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -156,29 +156,20 @@ def test_jitter_scale(self):
156156
with pm.Model() as pmodel:
157157
A = pm.HalfFlat("A", initval="support_point")
158158

159-
jitter_scale_tests = np.array([1.0, 2.0, 5.0])
160-
fns = []
161-
for jitter_scale in jitter_scale_tests:
162-
fns.append(
163-
make_initial_point_fn(
164-
model=pmodel,
165-
jitter_rvs=set(pmodel.free_RVs),
166-
jitter_scale=jitter_scale,
167-
return_transformed=True,
168-
)
169-
)
170-
171-
n_draws = 1000
172-
jitter_samples = np.empty((n_draws, len(fns)))
173-
for j, fn in enumerate(fns):
174-
# start and end to ensure random samples, otherwise jitter_samples across different jitter_scale will be an exact scale of each other
175-
start = j * n_draws
176-
end = start + n_draws
177-
jitter_samples[:, j] = np.asarray([fn(i)["A_log__"] for i in range(start, end)])
159+
fn_default = make_initial_point_fn(
160+
model=pmodel,
161+
jitter_rvs=set(pmodel.free_RVs),
162+
return_transformed=True,
163+
)
178164

179-
init_standardised = np.mean((jitter_samples / jitter_scale_tests), axis=0)
165+
fn_large = make_initial_point_fn(
166+
model=pmodel,
167+
jitter_rvs=set(pmodel.free_RVs),
168+
jitter_scale=1000.0,
169+
return_transformed=True,
170+
)
180171

181-
assert np.all((-0.05 < init_standardised) & (init_standardised < 0.05))
172+
assert fn_large(0)["A_log__"] > fn_default(0)["A_log__"]
182173

183174
def test_respects_overrides(self):
184175
with pm.Model() as pmodel:

0 commit comments

Comments
 (0)