Commit c0a6f03

Author: AoifeHughes
Add Stochastic Gradient Samplers documentation and enhance existing content
1 parent b32f6d5 commit c0a6f03

File tree

2 files changed (+36, -19 lines)


_quarto.yml

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ website:
       collapse-level: 1
       contents:
         - usage/automatic-differentiation/index.qmd
+        - usage/stochastic-gradient-samplers/index.qmd
         - usage/submodels/index.qmd
         - usage/custom-distribution/index.qmd
         - usage/probability-interface/index.qmd

usage/stochastic-gradient-samplers/index.qmd

Lines changed: 35 additions & 19 deletions
@@ -10,9 +10,16 @@ using Pkg;
 Pkg.instantiate();
 ```

-Turing.jl provides stochastic gradient-based MCMC samplers that are designed for large-scale datasets where computing full gradients is computationally expensive. The two main stochastic gradient samplers are **Stochastic Gradient Langevin Dynamics (SGLD)** and **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)**.
+Turing.jl provides stochastic gradient-based MCMC samplers: **Stochastic Gradient Langevin Dynamics (SGLD)** and **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)**.

-**Important**: The current implementation in Turing.jl computes full gradients with added stochastic noise rather than true mini-batch stochastic gradients. These samplers require very careful hyperparameter tuning and are typically most useful for research purposes or when working with streaming data.
+## Current Capabilities
+
+The current implementation in Turing.jl is primarily useful for:
+- **Research purposes**: Studying stochastic gradient MCMC methods
+- **Streaming data**: When data arrives continuously
+- **Experimental applications**: Testing stochastic sampling approaches
+
+**Important**: The current implementation computes full gradients with added stochastic noise rather than true mini-batch stochastic gradients. This means these samplers don't currently provide the computational benefits typically associated with stochastic gradient methods for large datasets. They require very careful hyperparameter tuning and often perform slower than standard samplers like HMC or NUTS for most practical applications.

 ## Setup

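For readers skimming the diff, the "full gradients with added stochastic noise" note above corresponds to the classic SGLD update of Welling and Teh (2011) applied to the whole dataset. A minimal sketch of one such step, assuming a hypothetical `grad_logjoint(θ)` that returns the full-data gradient of the log joint (an illustration, not Turing's internal code):

```julia
# One SGLD step: a half-stepsize gradient move plus Gaussian noise with variance ϵ.
# grad_logjoint is a hypothetical user-supplied function; in Turing this gradient
# is obtained from the model by automatic differentiation.
function sgld_step(θ, grad_logjoint, ϵ)
    return θ .+ (ϵ / 2) .* grad_logjoint(θ) .+ sqrt(ϵ) .* randn(length(θ))
end
```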
@@ -24,6 +31,9 @@ using Random
 using LinearAlgebra

 Random.seed!(123)
+
+# Disable progress bars for cleaner output
+Turing.setprogress!(false)
 ```

 ## SGLD (Stochastic Gradient Langevin Dynamics)
@@ -42,7 +52,7 @@ data = rand(Normal(true_μ, true_σ), N)
 # Define a simple Gaussian model
 @model function gaussian_model(x)
     μ ~ Normal(0, 10)
-    σ ~ truncated(Normal(0, 5), 0, Inf)
+    σ ~ truncated(Normal(0, 5); lower=0)

     for i in 1:length(x)
         x[i] ~ Normal(μ, σ)
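The prior change in this hunk is a spelling update rather than a behavioural one: `truncated(Normal(0, 5); lower=0)` should describe the same lower-truncated distribution as the positional form it replaces. A quick check using only Distributions.jl (treat it as a sanity sketch):

```julia
using Distributions

d_old = truncated(Normal(0, 5), 0, Inf)   # positional bounds, as in the removed line
d_new = truncated(Normal(0, 5); lower=0)  # keyword form, as in the added line

# Both should give the same density on the positive reals (up to floating point).
logpdf(d_old, 1.2) ≈ logpdf(d_new, 1.2)
```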
@@ -52,21 +62,17 @@ end
 model = gaussian_model(data)
 ```

-SGLD requires very small step sizes to ensure stability. We use a `PolynomialStepsize` that decreases over time:
+SGLD requires very small step sizes to ensure stability. We use a `PolynomialStepsize` that decreases over time. Note: Currently, `PolynomialStepsize` is the primary stepsize schedule available in Turing for SGLD:

 ```{julia}
 # SGLD with polynomial stepsize schedule
 # stepsize(t) = a / (b + t)^γ
 sgld_stepsize = Turing.PolynomialStepsize(0.0001, 10000, 0.55)
-chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 2000)
+chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 5000)

 summarystats(chain_sgld)
 ```

-```{julia}
-#| output: false
-setprogress!(false)
-```

 ```{julia}
 plot(chain_sgld)
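To get a feel for how slowly the schedule in this hunk decays, here is a sketch of the formula quoted in the code comment, `stepsize(t) = a / (b + t)^γ`, evaluated at the parameters used above. It assumes that comment accurately describes `Turing.PolynomialStepsize`; the numbers are illustrative:

```julia
# Polynomial decay schedule from the comment in the diff above.
poly_stepsize(a, b, γ, t) = a / (b + t)^γ

poly_stepsize(0.0001, 10000, 0.55, 1)     # ≈ 6.3e-7 at the first iteration
poly_stepsize(0.0001, 10000, 0.55, 5000)  # ≈ 5.0e-7 after 5000 iterations
```

With b = 10000 the step size barely changes over a 5000-iteration chain, which is consistent with the document's emphasis on very small, stable steps.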
@@ -78,7 +84,7 @@ SGHMC extends HMC to the stochastic gradient setting by incorporating friction t

 ```{julia}
 # SGHMC with very small learning rate
-chain_sghmc = sample(model, SGHMC(learning_rate=0.00001, momentum_decay=0.1), 2000)
+chain_sghmc = sample(model, SGHMC(learning_rate=0.00001, momentum_decay=0.1), 5000)

 summarystats(chain_sghmc)
 ```
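The hunk header mentions the friction term that distinguishes SGHMC from plain HMC. As intuition only, here is a minimal sketch of one SGHMC step in the style of Chen et al. (2014), with `η` standing in for `learning_rate` and `α` for `momentum_decay`; this is an assumption-laden illustration, not Turing's implementation:

```julia
# One SGHMC step: friction (α) damps the momentum, injected noise keeps the dynamics stochastic.
# grad_logjoint is again a hypothetical full-data gradient of the log joint.
function sghmc_step(θ, v, grad_logjoint, η, α)
    v = (1 - α) .* v .+ η .* grad_logjoint(θ) .+ sqrt(2 * α * η) .* randn(length(θ))
    θ = θ .+ v
    return θ, v
end
```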
@@ -98,7 +104,7 @@ println("True values: μ = ", true_μ, ", σ = ", true_σ)
 summarystats(chain_hmc)
 ```

-Compare the trace plots:
+Compare the trace plots to see how the different samplers explore the posterior:

 ```{julia}
 p1 = plot(chain_sgld[:μ], label="SGLD", title="μ parameter traces")
@@ -113,6 +119,11 @@ hline!([true_μ], label="True value", linestyle=:dash, color=:red)
 plot(p1, p2, p3, layout=(3,1), size=(800,600))
 ```

+The comparison shows that:
+- **SGLD** exhibits slower convergence and higher variance due to the injected noise, requiring longer chains to achieve stable estimates
+- **SGHMC** shows slightly better mixing than SGLD due to the momentum term, but still requires careful tuning
+- **HMC** converges quickly and efficiently explores the posterior, demonstrating why it's preferred for small to medium-sized problems
+
 ## Bayesian Linear Regression Example

 Here's a more complex example using Bayesian linear regression:
@@ -131,7 +142,7 @@ y = X * true_β + true_σ_noise * randn(n_samples)

     # Priors
     β ~ MvNormal(zeros(n_features), 3 * I)
-    σ ~ truncated(Normal(0, 1), 0, Inf)
+    σ ~ truncated(Normal(0, 1); lower=0)

     # Likelihood
     y ~ MvNormal(X * β, σ^2 * I)
@@ -145,14 +156,14 @@ Sample using the stochastic gradient methods:
 ```{julia}
 # Very conservative parameters for stability
 sgld_lr_stepsize = Turing.PolynomialStepsize(0.00005, 10000, 0.55)
-chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 3000)
+chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 5000)

-chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.00005, momentum_decay=0.1), 3000)
+chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.00005, momentum_decay=0.1), 5000)

 chain_lr_hmc = sample(lr_model, HMC(0.01, 10), 1000)
 ```

-Compare the results:
+Compare the results to evaluate the performance of stochastic gradient samplers on a more complex model:

 ```{julia}
 println("True β values: ", true_β)
@@ -163,9 +174,14 @@ println("SGLD estimates:")
 summarystats(chain_lr_sgld)
 ```

+The linear regression example demonstrates that stochastic gradient samplers can recover the true parameters, but:
+- They require significantly longer chains (5000 vs 1000 for HMC)
+- The estimates may have higher variance
+- Convergence diagnostics should be carefully examined before trusting the results
+
 ## Automatic Differentiation Backends

-Both samplers support different AD backends:
+Both samplers support different AD backends. For more information about automatic differentiation in Turing, see the [Automatic Differentiation](../automatic-differentiation/) documentation.

 ```{julia}
 using ADTypes
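The `adtype` keyword visible in the next hunk's header (`adtype=AutoZygote()`) suggests other ADTypes backends can be passed the same way, reusing `sgld_stepsize` from the earlier SGLD example. The constructor choices below are assumptions for illustration, and the matching AD packages need to be installed and loaded:

```julia
using ADTypes

# Hypothetical backend swaps; these mirror the AutoZygote example shown in the diff.
sgld_forwarddiff = SGLD(stepsize=sgld_stepsize, adtype=AutoForwardDiff())
sghmc_reversediff = SGHMC(learning_rate=0.00001, momentum_decay=0.1, adtype=AutoReverseDiff())
```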
@@ -182,11 +198,11 @@ sgld_zygote = SGLD(stepsize=sgld_stepsize, adtype=AutoZygote())

 ## Best Practices and Recommendations

-### When to Use Stochastic Gradient Samplers
+### When to Consider Stochastic Gradient Samplers

-- **Large datasets**: When full gradient computation is prohibitively expensive
-- **Streaming data**: When data arrives continuously
+- **Streaming data**: When data arrives continuously and you need online inference
 - **Research**: For studying stochastic gradient MCMC methods
+- **Educational purposes**: For understanding stochastic gradient MCMC algorithms

 ### Critical Hyperparameters

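Since the added text repeatedly stresses checking convergence before trusting stochastic-gradient results, here is a short sketch of what that check might look like with MCMCChains, applied to the linear-regression chain sampled earlier (assuming `ess` and `rhat` behave as in current MCMCChains releases):

```julia
using MCMCChains

# Per-parameter effective sample size and R-hat for the SGLD linear-regression chain.
ess(chain_lr_sgld)
rhat(chain_lr_sgld)

# Very small ESS or R-hat values far from 1 suggest the stepsize schedule or chain length needs retuning.
```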