---
title: "MCMC Sampling Options"
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

Markov chain Monte Carlo sampling in Turing.jl is performed using the `sample()` function.
As described on the [Core Functionality page]({{< meta core-functionality >}}), single-chain and multiple-chain sampling can be done using, respectively,

```julia
sample(model, sampler, niters)
sample(model, sampler, MCMCThreads(), niters, nchains) # or MCMCSerial() or MCMCDistributed()
```

On top of this, both methods also accept a number of keyword arguments that allow you to control the sampling process.
This page will detail these options.

To begin, let's create a simple model:

```{julia}
using Turing

@model function demo_model()
    x ~ Normal()
    y ~ Normal(x)
    4.0 ~ Normal(y)
    return nothing
end
```

## Controlling logging

Progress bars can be controlled with the `progress` keyword argument.
The exact values that can be used depend on whether you are using single-chain or multi-chain sampling.

For single-chain sampling, `progress=true` and `progress=false` enable and disable the progress bar, respectively.

For multi-chain sampling, `progress` can take the following values:

- `:none` or `false`: no progress bar
- (default) `:overall` or `true`: creates one overall progress bar for all chains
- `:perchain`: creates one overall progress bar, plus one extra progress bar per chain (note that this can lead to visual clutter if you have many chains)
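
For example (a sketch, not executed here; the bars are only visible in an interactive session), per-chain progress bars can be requested with:

```julia
chn = sample(demo_model(), MH(), MCMCThreads(), 1000, 4; progress=:perchain)
```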

If you want to globally enable or disable the progress bar, you can use:

```{julia}
Turing.setprogress!(false); # or true
```

(This handily also disables progress logging for the rest of this document.)

For NUTS in particular, you can also specify `verbose=false` to disable the "Found initial step size" info message.
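For example:

```julia
chn = sample(demo_model(), NUTS(), 100; verbose=false)
```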

## Ensuring sampling reproducibility

Like many other Julia functions, a `Random.AbstractRNG` object can be passed as the first argument to `sample()` to ensure reproducibility of results.

```{julia}
using Random
chn1 = sample(Xoshiro(468), demo_model(), MH(), 5);
chn2 = sample(Xoshiro(468), demo_model(), MH(), 5);
(chn1[:x] == chn2[:x], chn1[:y] == chn2[:y])
```

Alternatively, you can set the global RNG using `Random.seed!()`, although we recommend this less as it modifies global state.

```{julia}
Random.seed!(468)
chn3 = sample(demo_model(), MH(), 5);
Random.seed!(468)
chn4 = sample(demo_model(), MH(), 5);
(chn3[:x] == chn4[:x], chn3[:y] == chn4[:y])
```

::: {.callout-note}
The outputs of pseudorandom number generators in the standard `Random` library are not guaranteed to be the same across different Julia versions or platforms.
If you require absolute reproducibility, you should use [the StableRNGs.jl package](https://github.com/JuliaRandom/StableRNGs.jl).
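For example (a sketch, assuming StableRNGs.jl is installed in your environment), a `StableRNG` can be passed just like any other `AbstractRNG`:

```julia
using StableRNGs
chn = sample(StableRNG(468), demo_model(), MH(), 5)
```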
:::

## Switching the output type

By default, the results of MCMC sampling are bundled up in an `MCMCChains.Chains` object.

```{julia}
chn = sample(demo_model(), HMC(0.1, 20), 5)
typeof(chn)
```

If you wish to use a different chain format provided in another package, you can specify the `chain_type` keyword argument.
You should refer to the documentation of the respective package for exact details.

Another situation where specifying `chain_type` can be useful is when you want to obtain the raw MCMC outputs as a vector of transitions.
This can be used for profiling or debugging purposes (often, chain construction can take a surprising amount of time compared to sampling, especially for very simple models).
To do so, you can use `chain_type=Any` (i.e., do not convert the output to any specific chain format):

```{julia}
transitions = sample(demo_model(), MH(), 5; chain_type=Any)
typeof(transitions)
```

## Specifying initial parameters

In Turing.jl, initial parameters for MCMC sampling can be specified using the `initial_params` keyword argument.

For **single-chain sampling**, the AbstractMCMC interface generally expects that you provide a completely flattened vector of parameters.

```{julia}
chn = sample(demo_model(), MH(), 5; initial_params=[1.0, -5.0])
chn[:x][1], chn[:y][1]
```

::: {.callout-note}
Note that a number of samplers use warmup steps by default (see the [Thinning and warmup section below](#thinning-and-warmup)), so `chn[:param][1]` may not correspond to the exact initial parameters you provided.
`MH()` does not do this, which is why we use it here.
:::

Note that for Turing models, the use of `Vector` can be extremely error-prone as the order of parameters in the flattened vector is not always obvious (especially if there are parameters with non-trivial types).
In general, parameters should be provided in the order they are defined in the model.
A relatively 'safe' way of obtaining parameters in the correct order is to first generate a `VarInfo`, and then linearise that:

```{julia}
using DynamicPPL
vi = VarInfo(demo_model())
initial_params = vi[:]
```
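
This vector can then be passed to `sample()` (reusing `initial_params` from the block above):

```julia
chn = sample(demo_model(), MH(), 5; initial_params=initial_params)
```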

To avoid this situation, you can instead use a `NamedTuple` to specify initial parameters.

```{julia}
chn = sample(demo_model(), MH(), 5; initial_params=(y=2.0, x=-6.0))
chn[:x][1], chn[:y][1]
```

This works even for parameters with more complex types.

```{julia}
using LinearAlgebra # for the identity matrix I

@model function demo_complex()
    x ~ LKJCholesky(3, 0.5)
    y ~ MvNormal(zeros(3), I)
end
init_x, init_y = rand(LKJCholesky(3, 0.5)), rand(MvNormal(zeros(3), I))
chn = sample(demo_complex(), MH(), 5; initial_params=(x=init_x, y=init_y));
```

For **multiple-chain sampling**, the `initial_params` keyword argument should be a vector with length equal to the number of chains being sampled.
Each element of this vector should be the initial parameters for the corresponding chain, as described above.
Thus, for example, a vector of vectors, or a vector of `NamedTuple`s, can be used.
If you want to use the same initial parameters for all chains, you can use `fill`:

```{julia}
initial_params = fill((x=1.0, y=-5.0), 3)
chn = sample(demo_model(), MH(), MCMCThreads(), 5, 3; initial_params=initial_params)
chn[:x][1,:], chn[:y][1,:]
```

::: {.callout-important}
## Upcoming changes in Turing v0.41

In Turing v0.41, instead of providing _initial parameters_, users will have to provide what is conceptually an _initialisation strategy_.
The keyword argument is still `initial_params`, but the permitted values (for single-chain sampling) will be one of:

- `InitFromPrior()`: generate initial parameters by sampling from the prior
- `InitFromUniform(lower, upper)`: generate initial parameters by sampling uniformly from the given bounds in linked space
- `InitFromParams(namedtuple_or_dict)`: use the provided initial parameters, supplied either as a `NamedTuple` or a `Dict{<:VarName}`

Initialisation with `Vector` will be fully removed due to its inherent ambiguity.
Initialisation with a raw `NamedTuple` will still be supported (it will simply be wrapped in `InitFromParams()`); but we expect to remove this eventually, so it will be more future-proof to use `InitFromParams()` directly.
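
As a sketch of how this might look (based on the description above; not executed here, and the exact API may change before release):

```julia
# Hypothetical v0.41 usage
chn = sample(demo_model(), MH(), 5; initial_params=InitFromParams((x=1.0, y=-5.0)))
```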

For multiple chains, the same as above applies: the `initial_params` keyword argument should be a vector of initialisation strategies, one per chain.
:::

## Saving and resuming sampling

By default, MCMC sampling starts from scratch, using the initial parameters provided.
You can, however, resume sampling from a previous chain.
This is useful to, for example, perform sampling in batches, or to inspect intermediate results.

Firstly, the previous chain _must_ have been run using the `save_state=true` argument.

```{julia}
rng = Xoshiro(468)

chn1 = sample(rng, demo_model(), MH(), 5; save_state=true);
```

For `MCMCChains.Chains`, this results in the final sampler state being stored inside the chain metadata.
You can access it using `DynamicPPL.loadstate`:

```{julia}
saved_state = DynamicPPL.loadstate(chn1)
typeof(saved_state)
```

::: {.callout-note}
You can also directly access the saved sampler state with `chn1.info.samplerstate`, but we recommend _not_ using this as it relies on the internal structure of `MCMCChains.Chains`.
:::

Sampling can then be resumed from this state by providing it as the `initial_state` keyword argument.

```{julia}
chn2 = sample(demo_model(), MH(), 5; initial_state=saved_state)
```

Note that the exact format saved in `chn.info.samplerstate`, and that expected by `initial_state`, depends on the invocation of `sample` used.
For single-chain sampling, the saved state, and the required initial state, is just a single sampler state.
For multiple-chain sampling, it is a vector of states, one per chain.

This means that, for example, after sampling a single chain, you could sample three chains that branch off from that final state:

```{julia}
initial_states = fill(saved_state, 3)
chn3 = sample(demo_model(), MH(), MCMCThreads(), 5, 3; initial_state=initial_states)
```

::: {.callout-note}
## Initial states versus initial parameters

The `initial_state` and `initial_params` keyword arguments are mutually exclusive.
If both are provided, `initial_params` will be silently ignored.

```{julia}
chn2 = sample(rng, demo_model(), MH(), 5;
    initial_state=saved_state, initial_params=(x=0.0, y=0.0)
)
chn2[:x][1], chn2[:y][1]
```

In general, the saved state will contain a set of parameters (which will be the last parameters in the previous chain).
However, the saved state not only specifies parameters but also other internal variables required by the sampler.
For example, the MH state contains a cached log-density of the current parameters, which is later used for calculating the acceptance ratio.

Finally, note that the first sample in the resumed chain will not be the same as the last sample in the previous chain; it will be the sample immediately after that.

```{julia}
# In general these will not be the same (although it _could_ be if the MH step
# was rejected -- that is why we seed the sampling in this section).
chn1[:x][end], chn2[:x][1]
```
:::

## Thinning and warmup

The `num_warmup` and `discard_initial` keyword arguments can be used to control MCMC warmup.
Both of these are integers, and respectively specify the number of warmup steps to perform, and the number of iterations at the start of the chain to discard.
Note that the value of `discard_initial` should also include the `num_warmup` steps if you want the warmup steps to be discarded.

Here are some examples of how these two keyword arguments interact:

| `num_warmup=` | `discard_initial=` | Description |
| ------------- | ------------------ | ----------- |
| 10 | 10 | Perform 10 warmup steps, discard them; the chain starts from the first non-warmup step |
| 10 | 15 | Perform 10 warmup steps, discard them and the next 5 steps; the chain starts from the 6th non-warmup step |
| 10 | 5 | Perform 10 warmup steps, discard the first 5; the chain will contain 5 warmup steps followed by the rest of the chain |
| 0 | 10 | No warmup steps, discard the first 10 steps; the chain starts from the 11th step |
| 0 | 0 | No warmup steps, do not discard any steps; the chain starts from the 1st step (corresponding to the initial parameters) |

Each sampler has its own default value for `num_warmup`, but `discard_initial` always defaults to `num_warmup`.

Warmup steps and 'regular' non-warmup steps differ in that warmup steps call `AbstractMCMC.step_warmup`, whereas regular steps call `AbstractMCMC.step`.
For all the samplers defined in Turing, these two functions are identical; however, they may in general differ for other samplers.
Please consult the documentation of the respective sampler for details.

A thinning factor can be specified using the `thinning` keyword argument.
For example, `thinning=10` will keep every tenth sample, discarding the other nine.

Note that thinning is not applied to the first `discard_initial` samples; it is only applied to the remaining samples.
Thus, for example, if you use `discard_initial=50` and `thinning=10`, the chain will contain samples 51, 61, 71, and so on.
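
For example, a sketch combining warmup, discarding, and thinning (MH is used here only because it is fast):

```julia
# Perform 10 warmup steps; discard those plus the next 10 steps;
# then keep every 5th remaining sample.
chn = sample(demo_model(), MH(), 100;
    num_warmup=10, discard_initial=20, thinning=5)
```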

## Performing model checks

By default, DynamicPPL performs a number of checks on the model before any sampling is done.
These catch potential errors in a model, such as repeated variables (see [the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.DebugUtils.check_model_and_trace) for details).

If you wish to disable these checks, you can pass `check_model=false` to `sample()`.
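For example:

```julia
chn = sample(demo_model(), MH(), 5; check_model=false)
```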

## Callbacks

The `callback` keyword argument can be used to specify a function that is called at the end of each sampler iteration.
This function should have the signature `callback(rng, model, sampler, sample, iteration::Int; kwargs...)`.

If you are performing multi-chain sampling, `kwargs` will additionally contain `chain_number::Int`, which ranges from 1 to the number of chains.
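
As a minimal sketch (the function name and logging cadence are illustrative, not part of any API), a callback matching this signature could log progress:

```julia
# Hypothetical callback: report every 2nd iteration.
function my_callback(rng, model, sampler, sample, iteration::Int; kwargs...)
    iteration % 2 == 0 && @info "Completed iteration $iteration"
    return nothing
end

chn = sample(demo_model(), MH(), 10; callback=my_callback)
```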

The [TuringCallbacks.jl package](https://github.com/TuringLang/TuringCallbacks.jl) contains a `TensorBoardCallback`, which can be used to obtain live progress visualisations using [TensorBoard](https://www.tensorflow.org/tensorboard).

## Automatic differentiation

Finally, please note that for samplers which use automatic differentiation (e.g., HMC and NUTS), the AD type should be specified in the sampler constructor itself, rather than as a keyword argument to `sample()`.

In other words, this is correct:

```{julia}
spl = NUTS(; adtype=AutoForwardDiff())
chn = sample(demo_model(), spl, 10);
```

and not this:

```julia
spl = NUTS()
chn = sample(demo_model(), spl, 10; adtype=AutoForwardDiff())
```