Add documentation for Stochastic Gradient Samplers #629
Closed
Changes from all commits (22 commits):
- 6311202 Add documentation for Stochastic Gradient Samplers
- 4b3f7d1 Implement code changes to enhance functionality and improve performance
- 18ad4b0 Merge branch 'main' into add-stoc-docs (AoifeHughes)
- 64c3870 Add Stochastic Gradient Samplers documentation and enhance existing c…
- b8a2b6a bump versions
- 869f9d0 mani regen
- 3ae881b Merge branch 'main' into add-stoc-docs (AoifeHughes)
- b32f6d5 fix typo (penelopeysm)
- c0a6f03 Add Stochastic Gradient Samplers documentation and enhance existing c…
- 855a012 bump versions
- 3968c13 mani regen
- 4cded52 Color-Theme update to match main site for consistency (#613) (shravanngoswamii)
- f562afa fix urls for open graph (#630) (shravanngoswamii)
- dbf2a97 Remove MicroCanonicalHMC.jl and update external sampler docs (#628) (penelopeysm)
- 5a0f806 Added redirects (#632) (shravanngoswamii)
- 8793cfb Update README.md (shravanngoswamii)
- 9dce424 Merge branch 'add-stoc-docs' of https://github.com/TuringLang/docs in…
- 3e3aa73 Merge branch 'main' into add-stoc-docs (AoifeHughes)
- c64f29f updated messages and notices
- 9b269a7 Merge branch 'main' into add-stoc-docs (AoifeHughes)
- e083a4b bumped chain length
- 339da9d tried to tweak sampling and updated explainations
@@ -0,0 +1,296 @@
---
title: Stochastic Gradient Samplers
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

Turing.jl provides stochastic gradient-based MCMC samplers: **Stochastic Gradient Langevin Dynamics (SGLD)** and **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)**.

::: {.callout-warning}
## Research-Grade Implementation
These samplers are **primarily intended for research purposes** and require significant expertise to use effectively. For production use and most practical applications, we strongly recommend HMC or NUTS instead; they are more robust and more efficient.
:::

## Current Capabilities

The current implementation in Turing.jl is primarily useful for:

- **Research purposes**: studying stochastic gradient MCMC methods
- **Educational purposes**: understanding stochastic gradient MCMC algorithms
- **Streaming data**: when data arrives continuously (with careful tuning)
- **Experimental applications**: testing stochastic sampling approaches

**Important**: The current implementation computes full gradients with added stochastic noise rather than true mini-batch stochastic gradients. As a result, these samplers do not currently deliver the computational savings usually associated with stochastic gradient methods on large datasets. They also require very careful hyperparameter tuning and are often slower than standard samplers such as HMC or NUTS for most practical applications.

**Future Development**: These stochastic gradient samplers are being migrated to [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl) for better maintenance and development. Once the migration is complete, they will be available to Turing.jl as AbstractMCMC-compatible algorithms, and users who need research-grade stochastic gradient methods will be directed to AdvancedHMC.jl.

## Setup

```{julia}
using Turing
using Distributions
using StatsPlots
using Random
using LinearAlgebra

Random.seed!(123)

# Disable progress bars for cleaner output
Turing.setprogress!(false)
```

## SGLD (Stochastic Gradient Langevin Dynamics)

SGLD adds properly scaled Gaussian noise to gradient-descent steps to enable MCMC sampling: with noise of exactly the right scale, the optimization iterates turn into (approximate) samples from the posterior distribution.

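For orientation, the update rule behind SGLD (in the full-gradient form used by the current implementation; this is the standard Welling and Teh (2011) formulation, not Turing-specific notation) is

$$
\theta_{t+1} = \theta_t + \frac{\epsilon_t}{2}\,\nabla_\theta \log p(\theta_t \mid \text{data}) + \eta_t,
\qquad \eta_t \sim \mathcal{N}(0, \epsilon_t I),
$$

where $\epsilon_t$ is the step size at iteration $t$. Because the injected noise scales with $\epsilon_t$, the step size must shrink over time, which is why SGLD is paired with a decreasing stepsize schedule below.
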
Let's start with a simple Gaussian model:

```{julia}
# Generate synthetic data
true_μ = 2.0
true_σ = 1.5
N = 100
data = rand(Normal(true_μ, true_σ), N)

# Define a simple Gaussian model
@model function gaussian_model(x)
    μ ~ Normal(0, 10)
    σ ~ truncated(Normal(0, 5); lower=0)

    for i in eachindex(x)
        x[i] ~ Normal(μ, σ)
    end
end

model = gaussian_model(data)
```

SGLD requires very small step sizes to ensure stability. We use a `PolynomialStepsize` that decreases over time. Note: Currently, `PolynomialStepsize` is the primary stepsize schedule available in Turing for SGLD.

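Concretely, `PolynomialStepsize(a, b, γ)` corresponds to the decaying schedule

$$
\epsilon_t = \frac{a}{(b + t)^{\gamma}},
$$

so a larger `b` or a smaller `a` yields smaller, more stable steps.
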
**Important Note on Convergence**: The examples below use longer chains (10,000-15,000 samples) with the first half discarded as burn-in to ensure proper convergence. This is typical for stochastic gradient samplers, which require more samples than standard HMC/NUTS to achieve reliable results:

```{julia}
# SGLD with polynomial stepsize schedule
# stepsize(t) = a / (b + t)^γ
# Using smaller step size and longer chain for better convergence
sgld_stepsize = Turing.PolynomialStepsize(0.00005, 20000, 0.55)
chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 10000)

# Note: We use a longer chain (10000 samples) to ensure convergence
# The first half can be considered burn-in
summarystats(chain_sgld[5001:end])
```

```{julia}
# Plot the second half of the chain to show converged behavior
plot(chain_sgld[5001:end])
```

## SGHMC (Stochastic Gradient Hamiltonian Monte Carlo)

SGHMC extends HMC to the stochastic gradient setting by incorporating friction to counteract the noise from stochastic gradients.

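For background, SGHMC (Chen et al., 2014) augments the parameters $\theta$ with a momentum $v$ and, in its common practical form (ignoring the gradient-noise correction term), iterates

$$
v_{t+1} = (1 - \alpha)\, v_t + \eta\, \nabla_\theta \log p(\theta_t \mid \text{data}) + \mathcal{N}(0,\, 2\alpha\eta),
\qquad
\theta_{t+1} = \theta_t + v_{t+1}.
$$

Roughly, $\eta$ corresponds to the `learning_rate` keyword and $\alpha$ to `momentum_decay` (the friction); more friction damps the noise and improves stability at the cost of slower exploration:
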
```{julia}
# SGHMC with very small learning rate and longer chain
chain_sghmc = sample(model, SGHMC(learning_rate=0.000005, momentum_decay=0.2), 10000)

# Using the second half of the chain after burn-in
summarystats(chain_sghmc[5001:end])
```

```{julia}
# Plot the second half of the chain to show converged behavior
plot(chain_sghmc[5001:end])
```

## Comparison with Standard HMC

For comparison, let's sample the same model using standard HMC:

```{julia}
# Note: Using step size 0.05 instead of 0.01 for better exploration
# Step size 0.01 can be too small for this simple model, leading to poor mixing
chain_hmc = sample(model, HMC(0.05, 10), 1000)

println("True values: μ = ", true_μ, ", σ = ", true_σ)
summarystats(chain_hmc)
```

Compare the trace plots to see how the different samplers explore the posterior:

```{julia}
# Compare converged portions of the chains
# Note: Due to poor convergence of stochastic gradient methods, we show their
# full chains to illustrate the mixing issues
p1 = plot(chain_sgld[:μ], label="SGLD (full chain)", title="μ parameter traces",
          ylabel="SGLD values")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

p2 = plot(chain_sghmc[:μ], label="SGHMC (full chain)", ylabel="SGHMC values")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

p3 = plot(chain_hmc[:μ], label="HMC", ylabel="HMC values")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

plot(p1, p2, p3, layout=(3,1), size=(800,600))
```

## Understanding the Results

When examining the summary statistics, pay attention to these key diagnostics (the snippet after this list shows one way to compute them directly):

- **ESS (effective sample size)**: should be > 100 for reliable inference; SGLD/SGHMC often show ESS < 50, indicating poor mixing
- **R-hat**: should be < 1.01; values > 1.1 indicate convergence problems
- **MCSE (Monte Carlo standard error)**: should be small relative to the posterior standard deviation

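A minimal sketch of extracting these diagnostics programmatically, assuming the MCMCChains version shipped with Turing exports `ess_rhat` and `mcse` (as recent releases do):

```{julia}
using MCMCChains  # already a Turing dependency; loaded explicitly here for clarity

# Diagnostics for the post-burn-in SGLD chain
display(ess_rhat(chain_sgld[5001:end]))
display(mcse(chain_sgld[5001:end]))
```
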
The trace plots clearly illustrate the fundamental issues with stochastic gradient samplers:

- **SGLD** shows extremely poor mixing, taking tiny, noisy steps that barely explore the parameter space
- **SGHMC** exhibits similar problems, with minimal exploration despite the momentum term
- **HMC** explores the posterior properly, with good mixing and efficient sampling

This visual comparison explains why the ESS values are so low for the stochastic gradient methods: they are effectively stuck in small regions of the parameter space, taking steps too small to contribute to effective exploration.

## Bayesian Linear Regression Example

Here's a more complex example using Bayesian linear regression:

```{julia}
# Generate regression data
n_features = 3
n_samples = 100
X = randn(n_samples, n_features)
true_β = [0.5, -1.2, 2.1]
true_σ_noise = 0.3
y = X * true_β + true_σ_noise * randn(n_samples)

@model function linear_regression(X, y)
    n_features = size(X, 2)

    # Priors
    β ~ MvNormal(zeros(n_features), 3 * I)
    σ ~ truncated(Normal(0, 1); lower=0)

    # Likelihood
    y ~ MvNormal(X * β, σ^2 * I)
end

lr_model = linear_regression(X, y)
```

Sample using the stochastic gradient methods:

```{julia}
# Very conservative parameters for stability with longer chains
sgld_lr_stepsize = Turing.PolynomialStepsize(0.00002, 30000, 0.55)
chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 15000)

chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.000002, momentum_decay=0.3), 15000)

chain_lr_hmc = sample(lr_model, HMC(0.01, 10), 1000)
```

Compare the results to evaluate the performance of stochastic gradient samplers on a more complex model:

```{julia}
println("True β values: ", true_β)
println("True σ value: ", true_σ_noise)
println()

println("SGLD estimates (after burn-in):")
summarystats(chain_lr_sgld[7501:end])
```

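The SGHMC and HMC chains for this model can be inspected in the same way (applying the same burn-in convention to SGHMC):

```{julia}
println("SGHMC estimates (after burn-in):")
display(summarystats(chain_lr_sghmc[7501:end]))

println("HMC estimates:")
summarystats(chain_lr_hmc)
```
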
The linear regression example demonstrates that stochastic gradient samplers can recover the true parameters, but:

- They require significantly longer chains (15,000 samples versus 1,000 for HMC)
- We discard the first half as burn-in to ensure convergence
- The estimates may still have higher variance than HMC's
- Convergence diagnostics should be examined carefully before trusting the results

## Automatic Differentiation Backends

Both samplers support different AD backends. For more information about automatic differentiation in Turing, see the [Automatic Differentiation](../automatic-differentiation/) documentation.

```{julia}
using ADTypes

# ForwardDiff (default) - good for few parameters
sgld_forward = SGLD(stepsize=sgld_stepsize, adtype=AutoForwardDiff())

# ReverseDiff - better for many parameters
sgld_reverse = SGLD(stepsize=sgld_stepsize, adtype=AutoReverseDiff())

# Zygote - good for complex models
sgld_zygote = SGLD(stepsize=sgld_stepsize, adtype=AutoZygote())
```

## Best Practices and Recommendations

### When to Consider Stochastic Gradient Samplers

- **Streaming data**: when data arrives continuously and you need online inference
- **Research**: for studying stochastic gradient MCMC methods
- **Educational purposes**: for understanding stochastic gradient MCMC algorithms

### Critical Hyperparameters

**For SGLD:**

- Use `PolynomialStepsize` with very small initial values (≤ 0.0001)
- Larger `b` values in `PolynomialStepsize(a, b, γ)` provide more stability
- The stepsize decreases as `a / (b + t)^γ`
- **Recommended starting point**: `PolynomialStepsize(0.0001, 10000, 0.55)` (see the sketch below)
- **For unstable models**: reduce `a` to 0.00001 or increase `b` to 50000

**For SGHMC:**

- Use extremely small learning rates (≤ 0.00001)
- Momentum decay (friction) typically between 0.1 and 0.5
- Higher momentum decay improves stability but slows convergence
- **Recommended starting point**: `learning_rate=0.00001, momentum_decay=0.1`
- **For high-dimensional problems**: increase `momentum_decay` to 0.3-0.5

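Translating the two recommended starting points above into code (a sketch using the same constructors as the examples earlier on this page; treat the numbers as initial values to tune, not defaults):

```{julia}
# Recommended SGLD starting point
sgld_start = SGLD(stepsize=Turing.PolynomialStepsize(0.0001, 10000, 0.55))

# Recommended SGHMC starting point
sghmc_start = SGHMC(learning_rate=0.00001, momentum_decay=0.1)

# How the SGLD schedule a / (b + t)^γ decays with iteration t (plain arithmetic, not a Turing API call)
a, b, γ = 0.0001, 10000, 0.55
[(t, a / (b + t)^γ) for t in (1, 1_000, 10_000, 100_000)]
```
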
**For HMC (comparison baseline):**

- Start with a step size of 0.05-0.1 for simple models (2-3 parameters)
- For complex models (>5 parameters), try a step size of 0.01-0.05
- If you see poor mixing (low ESS), try increasing the step size
- If you see divergences or numerical issues, reduce the step size

**Tuning Strategy:**

1. **First establish an HMC baseline**: get HMC working with good ESS (>500) and R-hat < 1.01
2. Start with the recommended stochastic gradient values and run a short chain (500-1000 samples)
3. If chains diverge or parameters explode, reduce the step size by a factor of 10
4. If mixing is too slow, carefully increase the step size by a factor of 2
5. Always validate against HMC/NUTS results when possible

### Current Limitations

1. **No mini-batching**: full gradients are computed despite the "stochastic" name
2. **Hyperparameter sensitivity**: requires extensive tuning
3. **Computational overhead**: often slower than HMC/NUTS for small to medium datasets
4. **Convergence**: typically requires longer chains

### Convergence Diagnostics

Because of the high variance and slow convergence of stochastic gradient samplers, careful diagnostics are essential:

- **Visual inspection**: always check trace plots for all parameters
- **Effective sample size (ESS)**: expect lower ESS than with HMC/NUTS
- **R-hat values**: should be < 1.01 for all parameters
- **Long chains**: often need 5,000-10,000+ samples for convergence
- **Multiple chains**: run multiple chains with different initializations to verify convergence (see the sketch below)

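A sketch of the multiple-chains check, using the standard `sample(model, sampler, MCMCThreads(), n, nchains)` interface (this assumes Julia was started with several threads; `MCMCSerial()` works as a drop-in replacement otherwise):

```{julia}
# Four SGLD chains from different initializations; R-hat in the summary is computed across chains
multi_sgld = sample(model, SGLD(stepsize=sgld_stepsize), MCMCThreads(), 10_000, 4)
summarystats(multi_sgld[5001:end, :, :])
```
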
### General Recommendations

- **Start conservatively**: use very small step sizes initially
- **Monitor convergence**: check trace plots and diagnostics carefully
- **Increase samples if needed**: don't hesitate to use 10,000+ samples if convergence is poor
- **Compare with HMC/NUTS**: validate results when possible
- **Consider alternatives**: for most applications, HMC or NUTS will be more efficient

## Summary

Stochastic gradient samplers in Turing.jl provide an interface to gradient-based MCMC methods with added stochasticity. Although these methods were designed for large-scale problems, the current implementation uses full gradients, which makes them primarily useful for research or specialized applications. For most practical Bayesian inference tasks, standard samplers like HMC or NUTS will be more efficient and easier to tune.
This is a general comment, not related to the line it's attached to: the navigation bar on the left needs a new link to this page; currently, I think there's no way to navigate to it without knowing the URL.