Skip to content

Commit 8a177f2

Browse files
ColCarrolltwiecki
authored andcommitted
Fix HamiltonianMC init to forward arguments (#3469)
* Fix HamiltonianMC init to forward arguments * Update first 100 steps of HMC * Comments
1 parent 55ccacf commit 8a177f2

File tree

3 files changed

+117
-109
lines changed

3 files changed

+117
-109
lines changed

RELEASE-NOTES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
- Added the `distributions.shape_utils` module with functions used to help broadcast samples drawn from distributions using the `size` keyword argument.
1212
- Used `numpy.vectorize` in `distributions.distribution._compile_theano_function`. This enables `sample_prior_predictive` and `sample_posterior_predictive` to ask for tuples of samples instead of just integers. This fixes issue #3422.
1313

14+
### Fixes
15+
16+
- `HamiltonianMC` was ignoring certain arguments like `target_accept`, and not using the custom step size jitter function with expectation 1.
17+
1418
### Maintenance
19+
1520
- All occurances of `sd` as a parameter name have been renamed to `sigma`. `sd` will continue to function for backwards compatibility.
1621
- Made `BrokenPipeError` for parallel sampling more verbose on Windows.
1722
- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).
@@ -39,6 +44,7 @@
3944
- Fixed a defect found in `Bound.random_` where `total_size` could end up as a `float64` instead of being an integer if given `size=tuple()`.
4045
- Fixed an issue in `model_graph` that caused construction of the graph of the model for rendering to hang: replaced a search over the powerset of the nodes with a breadth-first search over the nodes. Fix for #3458.
4146
- Removed variable annotations from `model_graph` but left type hints (Fix for #3465). This means that we support `python>=3.5.4`.
47+
- Default `target_accept`for `HamiltonianMC` is now 0.65, as suggested in Beskos et. al. 2010 and Neal 2001.
4248

4349
### Deprecations
4450

pymc3/step_methods/hmc/hmc.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ class HamiltonianMC(BaseHMC):
3737
'model_logp': np.float64,
3838
}]
3939

40-
def __init__(self, vars=None, path_length=2.,
41-
adapt_step_size=True, gamma=0.05, k=0.75, t0=10,
42-
target_accept=0.8, **kwargs):
40+
def __init__(self, vars=None, path_length=2., **kwargs):
4341
"""Set up the Hamiltonian Monte Carlo sampler.
4442
4543
Parameters
@@ -63,19 +61,22 @@ def __init__(self, vars=None, path_length=2.,
6361
An object that represents the Hamiltonian with methods `velocity`,
6462
`energy`, and `random` methods. It can be specified instead
6563
of the scaling matrix.
66-
target_accept : float, default .8
64+
target_accept : float, default 0.65
6765
Adapt the step size such that the average acceptance
6866
probability across the trajectories are close to target_accept.
6967
Higher values for target_accept lead to smaller step sizes.
7068
Setting this to higher values like 0.9 or 0.99 can help
7169
with sampling from difficult posteriors. Valid values are
72-
between 0 and 1 (exclusive).
70+
between 0 and 1 (exclusive). Default of 0.65 is from (Beskos et.
71+
al. 2010, Neal 2011). See Hoffman and Gelman's "The No-U-Turn
72+
Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte
73+
Carlo" section 3.2 for details.
7374
gamma : float, default .05
7475
k : float, default .75
7576
Parameter for dual averaging for step size adaptation. Values
7677
between 0.5 and 1 (exclusive) are admissible. Higher values
7778
correspond to slower adaptation.
78-
t0 : int, default 10
79+
t0 : float > 0, default 10
7980
Parameter for dual averaging. Higher values slow initial
8081
adaptation.
8182
adapt_step_size : bool, default=True
@@ -85,12 +86,13 @@ def __init__(self, vars=None, path_length=2.,
8586
The model
8687
**kwargs : passed to BaseHMC
8788
"""
89+
kwargs.setdefault('step_rand', unif)
90+
kwargs.setdefault('target_accept', 0.65)
8891
super().__init__(vars, **kwargs)
8992
self.path_length = path_length
9093

9194
def _hamiltonian_step(self, start, p0, step_size):
92-
path_length = np.random.rand() * self.path_length
93-
n_steps = max(1, int(path_length / step_size))
95+
n_steps = max(1, int(self.path_length / step_size))
9496

9597
energy_change = -np.inf
9698
state = start
@@ -120,7 +122,7 @@ def _hamiltonian_step(self, start, p0, step_size):
120122
accepted = True
121123

122124
stats = {
123-
'path_length': path_length,
125+
'path_length': self.path_length,
124126
'n_steps': n_steps,
125127
'accept': accept_stat,
126128
'energy_error': energy_change,

pymc3/tests/test_step.py

Lines changed: 100 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -155,106 +155,106 @@ class TestStepMethods: # yield test doesn't work subclassing object
155155
),
156156
HamiltonianMC: np.array(
157157
[
158-
0.43733634,
159-
0.43733634,
160-
0.15955614,
161-
-0.44355329,
162-
0.21465731,
163-
0.30148244,
164-
0.45527282,
165-
0.45527282,
166-
0.41753005,
167-
-0.03480236,
168-
1.16599611,
169-
0.565306,
170-
0.565306,
171-
0.0077143,
172-
-0.18291321,
173-
-0.14577946,
174-
-0.00703353,
175-
-0.00703353,
176-
0.14345194,
177-
-0.12345058,
178-
0.76875516,
179-
0.76875516,
180-
0.84289506,
181-
0.24596225,
182-
0.95287087,
183-
1.3799335,
184-
1.1493899,
185-
1.1493899,
186-
2.0255982,
187-
-0.77850273,
188-
0.11604115,
189-
0.11604115,
190-
0.39296557,
191-
0.34826491,
192-
0.5951183,
193-
0.63097341,
194-
0.57938784,
195-
0.57938784,
196-
0.76570029,
197-
0.63516046,
198-
0.23667784,
199-
2.0151377,
200-
1.92064966,
201-
1.09125654,
202-
-0.43716787,
203-
0.61939595,
204-
0.30566853,
205-
0.30566853,
206-
0.3690641,
207-
0.3690641,
208-
0.3690641,
209-
1.26497542,
210-
0.90890334,
211-
0.01482818,
212-
0.01482818,
213-
-0.15542473,
214-
0.26475651,
215-
0.32687263,
216-
1.21902207,
217-
0.6708017,
218-
-0.18867695,
219-
-0.18867695,
220-
-0.07141329,
221-
-0.04631175,
222-
-0.16855462,
223-
-0.16855462,
224-
1.05455573,
225-
0.47371825,
226-
0.47371825,
227-
0.86307077,
228-
0.86307077,
229-
0.51484125,
230-
1.0022533,
231-
1.0022533,
232-
1.02370316,
233-
0.71331829,
234-
0.71331829,
235-
0.71331829,
236-
0.40758664,
237-
0.81307434,
238-
-0.46269741,
239-
-0.60284666,
240-
0.06710527,
241-
0.06710527,
242-
-0.35055053,
243-
0.36727629,
244-
0.36727629,
245-
0.69350367,
246-
0.11268647,
247-
0.37681301,
248-
1.10168386,
249-
0.49559472,
250-
0.49559472,
251-
0.06193658,
252-
-0.07947103,
253-
0.01969434,
254-
1.28470893,
255-
-0.13536813,
256-
-0.13536813,
257-
0.6575966,
158+
1.43583525,
159+
1.43583525,
160+
1.43583525,
161+
-0.57415005,
162+
0.91472062,
163+
0.91472062,
164+
0.36282799,
165+
0.80991631,
166+
0.84457253,
167+
0.84457253,
168+
-0.12651784,
169+
-0.12651784,
170+
0.39027088,
171+
-0.22998424,
172+
0.64337475,
173+
0.64337475,
174+
0.03504003,
175+
1.2667789,
176+
1.2667789,
177+
0.34770874,
178+
0.224319,
179+
0.224319,
180+
1.00416894,
181+
0.46161403,
182+
0.28217305,
183+
0.28217305,
184+
0.50327811,
185+
0.50327811,
186+
0.50327811,
187+
0.50327811,
188+
0.42335724,
189+
0.42335724,
190+
0.20336198,
191+
0.20336198,
192+
0.20336198,
193+
0.16330229,
194+
0.16330229,
195+
-0.7332075,
196+
1.04924226,
197+
1.04924226,
198+
0.39630439,
199+
0.16481719,
200+
0.16481719,
201+
0.84146061,
202+
0.83146709,
203+
0.83146709,
204+
0.32748059,
205+
1.00918804,
206+
1.00918804,
207+
0.91034823,
208+
1.31278027,
209+
1.38222654,
210+
1.38222654,
211+
-0.32268814,
212+
-0.32268814,
213+
2.1866116,
214+
1.21679252,
215+
-0.15916878,
216+
-0.15916878,
217+
0.38958249,
218+
0.38958249,
219+
0.54971928,
220+
0.05591406,
221+
0.87712017,
222+
0.87712017,
223+
0.19409043,
224+
0.19409043,
225+
0.19409043,
226+
0.40718849,
227+
0.63399349,
228+
0.35510353,
229+
0.35510353,
230+
0.47860847,
231+
0.47860847,
232+
0.69805772,
233+
0.16686305,
234+
0.16686305,
235+
0.16686305,
236+
0.04971251,
237+
0.04971251,
238+
-0.90052793,
239+
-0.73203754,
240+
1.02258958,
241+
1.02258958,
242+
-0.14144856,
243+
-0.14144856,
244+
1.43017486,
245+
1.23202605,
246+
1.23202605,
247+
0.24442885,
248+
0.78300516,
249+
0.30494261,
250+
0.30494261,
251+
0.30494261,
252+
-0.00596443,
253+
1.31695235,
254+
0.81375848,
255+
0.81375848,
256+
0.81375848,
257+
1.91238675
258258
]
259259
),
260260
Metropolis: np.array(

0 commit comments

Comments
 (0)