---
title: Stochastic Gradient Samplers
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

Turing.jl provides stochastic gradient-based MCMC samplers designed for large-scale problems where evaluating the gradient over the full dataset at every step is computationally expensive. The two samplers are **Stochastic Gradient Langevin Dynamics (SGLD)** and **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)**.

**Important**: The current implementation in Turing.jl does not subsample the data. Each step evaluates the gradient of the full log-joint, and the stochasticity comes only from the injected noise rather than from true mini-batch gradients. These samplers require very careful hyperparameter tuning and are typically most useful for research purposes or when working with streaming data.

## Setup

```{julia}
using Turing
using Distributions
using StatsPlots
using Random
using LinearAlgebra

Random.seed!(123)
```

## SGLD (Stochastic Gradient Langevin Dynamics)

SGLD adds properly scaled Gaussian noise to gradient steps on the log posterior. The key insight is that, with the right noise scale and a decaying step size, this turns an optimization procedure into a sampler whose iterates approximate draws from the posterior distribution.
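
Concretely, at iteration $t$ with step size $\epsilon_t$, the update of Welling & Teh (2011) is

$$
\theta_{t+1} = \theta_t + \frac{\epsilon_t}{2}\,\nabla_\theta \log p(\theta_t \mid x) + \eta_t,
\qquad \eta_t \sim \mathcal{N}(0, \epsilon_t I),
$$

where in the mini-batch setting the gradient is replaced by a subsampled estimate; as noted above, Turing currently evaluates it on the full dataset.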

Let's start with a simple Gaussian model:

```{julia}
# Generate synthetic data
true_μ = 2.0
true_σ = 1.5
N = 100
data = rand(Normal(true_μ, true_σ), N)

# Define a simple Gaussian model
@model function gaussian_model(x)
    μ ~ Normal(0, 10)
    σ ~ truncated(Normal(0, 5), 0, Inf)

    for i in eachindex(x)
        x[i] ~ Normal(μ, σ)
    end
end

model = gaussian_model(data)
```

SGLD requires very small step sizes for stability, so we use a `PolynomialStepsize` schedule that decays over the course of sampling:

```{julia}
# SGLD with a polynomial stepsize schedule
# stepsize(t) = a / (b + t)^γ
sgld_stepsize = Turing.PolynomialStepsize(0.0001, 10000, 0.55)
chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 2000)

summarystats(chain_sgld)
```

```{julia}
#| output: false
setprogress!(false)
```

```{julia}
plot(chain_sgld)
```
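
To get a sense of how slowly this schedule decays, we can evaluate the formula `a / (b + t)^γ` directly with the values used above (plain arithmetic, independent of any Turing API):

```{julia}
# Step size a / (b + t)^γ with a = 0.0001, b = 10000, γ = 0.55
poly_stepsize(t; a=0.0001, b=10_000, γ=0.55) = a / (b + t)^γ
[poly_stepsize(t) for t in (1, 1_000, 2_000)]
```

With `b = 10000`, the effective step size starts near `6e-7` and changes very little over the 2000 iterations used above.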

## SGHMC (Stochastic Gradient Hamiltonian Monte Carlo)

SGHMC extends HMC to the stochastic-gradient setting by adding a friction term that counteracts the extra noise introduced by stochastic gradients.
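
Writing $\epsilon$ for the learning rate and $\alpha$ for the momentum decay (friction), a common discretization (Chen et al., 2014) updates an auxiliary momentum $v$ alongside the parameters, schematically (the exact scaling used internally may differ slightly):

$$
\begin{aligned}
v_{t+1} &= (1 - \alpha)\,v_t + \epsilon\,\nabla_\theta \log p(\theta_t \mid x) + \mathcal{N}(0,\ 2\alpha\epsilon),\\
\theta_{t+1} &= \theta_t + v_{t+1}.
\end{aligned}
$$

The friction term dissipates the extra energy that gradient noise would otherwise pump into the dynamics. Sampling with Turing's `SGHMC` looks like this: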

```{julia}
# SGHMC with very small learning rate
chain_sghmc = sample(model, SGHMC(learning_rate=0.00001, momentum_decay=0.1), 2000)

summarystats(chain_sghmc)
```

```{julia}
plot(chain_sghmc)
```

## Comparison with Standard HMC

For comparison, let's sample the same model using standard HMC:

```{julia}
chain_hmc = sample(model, HMC(0.01, 10), 1000)

println("True values: μ = ", true_μ, ", σ = ", true_σ)
summarystats(chain_hmc)
```

Compare the trace plots:

```{julia}
p1 = plot(chain_sgld[:μ], label="SGLD", title="μ parameter traces")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

p2 = plot(chain_sghmc[:μ], label="SGHMC")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

p3 = plot(chain_hmc[:μ], label="HMC")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

plot(p1, p2, p3, layout=(3,1), size=(800,600))
```
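
As a quick numerical check to go with the plots, compare the posterior means of μ against the true value:

```{julia}
println("Posterior mean of μ (SGLD):  ", mean(chain_sgld[:μ]))
println("Posterior mean of μ (SGHMC): ", mean(chain_sghmc[:μ]))
println("Posterior mean of μ (HMC):   ", mean(chain_hmc[:μ]))
println("True value:                  ", true_μ)
```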

## Bayesian Linear Regression Example

Here's a more complex example using Bayesian linear regression:

```{julia}
# Generate regression data
n_features = 3
n_samples = 100
X = randn(n_samples, n_features)
true_β = [0.5, -1.2, 2.1]
true_σ_noise = 0.3
y = X * true_β + true_σ_noise * randn(n_samples)

@model function linear_regression(X, y)
    n_features = size(X, 2)

    # Priors
    β ~ MvNormal(zeros(n_features), 3 * I)
    σ ~ truncated(Normal(0, 1), 0, Inf)

    # Likelihood
    y ~ MvNormal(X * β, σ^2 * I)
end

lr_model = linear_regression(X, y)
```

Sample using the stochastic gradient methods:

```{julia}
# Very conservative parameters for stability
sgld_lr_stepsize = Turing.PolynomialStepsize(0.00005, 10000, 0.55)
chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 3000)

chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.00005, momentum_decay=0.1), 3000)

chain_lr_hmc = sample(lr_model, HMC(0.01, 10), 1000)
```

Compare the results:

```{julia}
println("True β values: ", true_β)
println("True σ value: ", true_σ_noise)
println()

println("SGLD estimates:")
summarystats(chain_lr_sgld)
```
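
The SGHMC and HMC chains can be summarized in the same way, which makes it easy to see how far each sampler has moved toward the true coefficients:

```{julia}
println("SGHMC estimates:")
summarystats(chain_lr_sghmc)
```

```{julia}
println("HMC estimates:")
summarystats(chain_lr_hmc)
```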

## Automatic Differentiation Backends

Both samplers accept an `adtype` argument that selects the automatic differentiation backend. Note that to actually sample with a backend other than the default, the corresponding AD package (e.g. ReverseDiff or Zygote) must also be loaded:

```{julia}
using ADTypes

# ForwardDiff (default) - good for models with few parameters
sgld_forward = SGLD(stepsize=sgld_stepsize, adtype=AutoForwardDiff())

# ReverseDiff - often better for models with many parameters
sgld_reverse = SGLD(stepsize=sgld_stepsize, adtype=AutoReverseDiff())

# Zygote - useful for complex models
sgld_zygote = SGLD(stepsize=sgld_stepsize, adtype=AutoZygote())
```
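
As a quick usage sketch, the ForwardDiff-backed sampler defined above can be run directly (the ReverseDiff and Zygote variants are used the same way once the corresponding package is loaded):

```{julia}
chain_sgld_fd = sample(model, sgld_forward, 500)
summarystats(chain_sgld_fd)
```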

## Best Practices and Recommendations

### When to Use Stochastic Gradient Samplers

- **Large datasets**: When full gradient computation is prohibitively expensive
- **Streaming data**: When data arrives continuously
- **Research**: For studying stochastic gradient MCMC methods

### Critical Hyperparameters

**For SGLD:**
- Use `PolynomialStepsize` with very small initial values (≤ 0.0001)
- Larger `b` values in `PolynomialStepsize(a, b, γ)` provide more stability (see the short calculation after this list)
- The step size decreases as `a / (b + t)^γ`
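
For example, evaluating the first step of the schedule for two values of `b` (plain arithmetic, using the `a` and `γ` values from this tutorial) shows how a larger `b` produces smaller, more conservative early steps:

```{julia}
first_step(b; a=0.0001, γ=0.55) = a / (b + 1)^γ
(b_1000 = first_step(1_000), b_10000 = first_step(10_000))
```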

**For SGHMC:**
- Use extremely small learning rates (≤ 0.00001)
- Momentum decay (friction) typically between 0.1 and 0.5
- Higher momentum decay improves stability but slows convergence

### Current Limitations

1. **No mini-batching**: Full gradients are computed despite the "stochastic" name
2. **Hyperparameter sensitivity**: Requires extensive tuning
3. **Computational overhead**: Often slower than HMC/NUTS on small-to-medium datasets
4. **Convergence**: Typically requires longer chains

### General Recommendations

- **Start conservatively**: Use very small step sizes initially
- **Monitor convergence**: Check trace plots and diagnostics carefully
- **Compare with HMC/NUTS**: Validate results when possible
- **Consider alternatives**: For most applications, HMC or NUTS will be more efficient

## Summary

Stochastic gradient samplers in Turing.jl provide an interface to gradient-based MCMC methods with added stochasticity. Although they are designed for large-scale problems, the current implementation uses full gradients, so they are primarily useful for research or specialized applications. For most practical Bayesian inference tasks, standard samplers such as HMC or NUTS will be more efficient and easier to tune.