Skip to content

Conversation

@andrewdipper
Copy link

Fix examples after change to extend_params in blackjax-devs/blackjax#694.

Additionally in the TemperedSMC example max_num_doublings was changed to 6 instead of the default 10 since we regularly hit max_num_doublings due to the small step size (I believe this is for illustrative purposes). On a gpu device the example is extraordinarily slow without the change - and still takes ~2 mins with it. It seems far too slow but I haven't been able to find any explanation.

For reference:

CPU (10000 samples, max_num_doublings=10):
step_size = 1e-2:
HMC: 50 steps / 1.14s
NUTS: 30 steps / .964s

step_size = 1e-3
HMC: 50 / 1.14s
NUTS: 273 / 1.9s

step_size = 1e-4
HMC: 50 / 1.18s
NUTS: 926 / 4.23s

GPU (1000 samples - 10x fewer samples..., max_num_doublings=10):
step_size = 1e-2:
HMC: 50 / 3.31s
NUTS: 30 / 7.3s

step_size = 1e-3:
HMC: 50 / 3.32s
NUTS: 267 / 63s

step_size = 1e-4
HMC: 50 / 3.31s
NUTS: 926.4 / 237s

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants