Commits (37)
ea69430
update vi interface to match [email protected]
Red-Portal Oct 22, 2025
86ee6dd
revert unintended commit of `runtests.jl`
Red-Portal Oct 22, 2025
3e30e04
Merge branch 'breaking' of github.com:TuringLang/Turing.jl into bump_…
Red-Portal Oct 24, 2025
d870045
update docs for `vi`
Red-Portal Oct 24, 2025
2d928e0
add history entry for `[email protected]`
Red-Portal Oct 24, 2025
5211b37
remove export for removed symbol
Red-Portal Oct 24, 2025
f0d615d
fix formatting
Red-Portal Oct 24, 2025
1b2351f
fix formatting
Red-Portal Oct 24, 2025
2be31b4
tidy tests advi
Red-Portal Oct 24, 2025
e48ae42
fix rename file `advi.jl` to `vi.jl` to reflect naming changes
Red-Portal Oct 24, 2025
44f7762
fix docs
Red-Portal Oct 25, 2025
fd0e928
fix HISTORY.md
Red-Portal Oct 25, 2025
77276bd
fix HISTORY.md
Red-Portal Oct 25, 2025
cb1620c
Merge branch 'main' of github.com:TuringLang/Turing.jl into bump_adva…
Red-Portal Oct 25, 2025
e70ddb4
update history
Red-Portal Oct 25, 2025
115802d
Merge branch 'bump_advancedvi_0.5' of github.com:TuringLang/Turing.jl…
Red-Portal Oct 25, 2025
cdc8b2f
Update README.md for clarity and formatting
yebai Nov 12, 2025
32e70d6
Add linear regression model example to README
yebai Nov 12, 2025
19bf7d6
Add dark/light mode logo support (#2714)
shravanngoswamii Nov 12, 2025
25b5087
Merge branch 'main' of github.com:TuringLang/Turing.jl into bump_adva…
Red-Portal Nov 19, 2025
4c02f7b
bump AdvancedVI version
Red-Portal Nov 19, 2025
6518b82
add exports new algorithms, modify `vi` to operate in unconstrained
Red-Portal Nov 19, 2025
5bd6978
Merge branch 'breaking' of github.com:TuringLang/Turing.jl into bump_…
Red-Portal Nov 19, 2025
874a0b2
add clarification on initializing unconstrained algorithms
Red-Portal Nov 19, 2025
e021eb7
update api
Red-Portal Nov 19, 2025
eec7ef2
run formatter
Red-Portal Nov 19, 2025
b6d8202
run formatter
Red-Portal Nov 19, 2025
b900ab4
run formatter
Red-Portal Nov 19, 2025
e71b07b
run formatter
Red-Portal Nov 19, 2025
c08de12
run formatter
Red-Portal Nov 19, 2025
ae80f1e
run formatter
Red-Portal Nov 19, 2025
73bd309
run formatter
Red-Portal Nov 19, 2025
eaac4c3
run formatter
Red-Portal Nov 19, 2025
757ebb4
revert changes to README
Red-Portal Nov 19, 2025
05ab711
fix wrong use of transformation in vi
Red-Portal Nov 20, 2025
91606b5
change inital value for scale matrices to 0.6*I and update docs
Red-Portal Nov 20, 2025
722153a
run formatter
Red-Portal Nov 20, 2025
68 changes: 68 additions & 0 deletions HISTORY.md
@@ -18,6 +18,74 @@
As long as the above functions are defined correctly, Turing will be able to use…

The `Turing.Inference.isgibbscomponent(::MySampler)` interface function still exists, but in this version the default has been changed to `true`, so you should not need to overload this.

## **AdvancedVI 0.6**

Turing.jl v0.42 updates its `AdvancedVI.jl` compatibility to 0.6 (we skipped the breaking 0.5 release, as it does not introduce new features).
`[email protected]` introduces major structural changes, including breaking changes to the interface, along with multiple new features.
The summary below covers only the changes that affect end users of Turing.
For a more comprehensive list of changes, please refer to the [changelog](https://github.com/TuringLang/AdvancedVI.jl/blob/main/HISTORY.md) of `AdvancedVI`.

### Breaking Changes

A new interface for defining variational algorithms was introduced in `AdvancedVI` v0.5. As a result, the function `Turing.vi` now takes a keyword argument `algorithm`. The object `algorithm <: AdvancedVI.AbstractVariationalAlgorithm` should contain all of the algorithm-specific configuration. Accordingly, keyword arguments of `vi` that were algorithm-specific, such as `objective`, `operator`, and `averager`, have been moved to fields of the relevant `<: AdvancedVI.AbstractVariationalAlgorithm` structs.
For example,

```julia
vi(model, q, n_iters; objective=RepGradELBO(10), operator=AdvancedVI.ClipScale())
```

is now

```julia
vi(
model,
q,
n_iters;
algorithm=KLMinRepGradDescent(adtype; n_samples=10, operator=AdvancedVI.ClipScale()),
)
```

Similarly,

```julia
vi(
model,
q,
n_iters;
objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
)
```

is now

```julia
vi(model, q, n_iters; algorithm=KLMinRepGradProxDescent(adtype; n_samples=10))
```

Additionally,

- The default hyperparameters of `DoG` and `DoWG` have been altered.
- The deprecated `[email protected]`-era interface has been removed.
- `estimate_objective` now returns the value minimized by the optimization algorithm. For example, for ELBO-maximization algorithms, `estimate_objective` returns the *negative ELBO*. This is a breaking change from the previous behavior, where the ELBO itself was returned.
- When using algorithms that operate in unconstrained spaces, the user is now explicitly expected to provide a `Bijectors.TransformedDistribution` wrapping an unconstrained distribution; see the sketch below. (Refer also to the docstring of `vi`.)
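
For instance, here is a minimal sketch of the last point. The toy model is hypothetical; `q_meanfield_gaussian` returns a `Bijectors.TransformedDistribution` wrapping an unconstrained Gaussian, and the `(q, info, state)` return order follows this release:

```julia
using Turing

@model function demo(x)
    # σ is constrained to (0, ∞), so VI runs in a transformed space
    σ ~ truncated(Normal(0, 1); lower=0)
    x ~ Normal(0, σ)
end

model = demo(1.0)

# A valid initialization for unconstrained-space algorithms such as the
# default `KLMinRepGradProxDescent`: a `Bijectors.TransformedDistribution`
# wrapping an unconstrained mean-field Gaussian.
q_init = q_meanfield_gaussian(model)

q, info, state = vi(model, q_init, 1_000)
```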

### New Features

`[email protected]` adds numerous new features, including the following new VI algorithms:

- `KLMinWassFwdBwd`: Also known as "Wasserstein variational inference," this algorithm minimizes the KL divergence under the Wasserstein-2 metric.
- `KLMinNaturalGradDescent`: Also known as "online variational Newton," this is the canonical "black-box" natural gradient variational inference algorithm, which minimizes the KL divergence via mirror descent, using the KL divergence as the Bregman divergence.
- `KLMinSqrtNaturalGradDescent`: A recent variant of `KLMinNaturalGradDescent` that operates in the Cholesky-factor parameterization of Gaussians instead of precision matrices.
- `FisherMinBatchMatch`: Called "batch-and-match," this algorithm minimizes a variational formulation of the second-order Fisher divergence via a proximal point-type algorithm.

Any of the new algorithms above can readily be used by simply swapping the `algorithm` keyword argument of `vi`.
For example, to use batch-and-match:

```julia
vi(model, q, n_iters; algorithm=FisherMinBatchMatch())
```

# 0.41.1

The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
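A minimal sketch of this follows. The toy model is hypothetical; per the note above, `InitFromParams` is assumed to accept the `ModeResult` directly, and `sample` to accept the strategy through its `initial_params` keyword:

```julia
using Turing

@model function gdemo(x)
    μ ~ Normal(0, 1)
    x ~ Normal(μ, 1)
end

model = gdemo(2.0)

# Obtain a mode estimate and use it to initialize MCMC.
mode_est = maximum_a_posteriori(model)
chain = sample(model, NUTS(), 1_000; initial_params=InitFromParams(mode_est))
```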
2 changes: 1 addition & 1 deletion Project.toml
@@ -55,7 +55,7 @@
Accessors = "0.1"
AdvancedHMC = "0.8.3"
AdvancedMH = "0.8.9"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.6"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Compat = "4.15.0"
41 changes: 25 additions & 16 deletions README.md
@@ -1,6 +1,10 @@
<p align="center"><img src="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/refs/heads/main/assets/logo/turing-logo.svg" alt="Turing.jl logo" width="200" /></p>
<h1 align="center">Turing.jl</h1>
<p align="center"><i>Probabilistic programming and Bayesian inference in Julia</i></p>
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://turinglang.org/assets/logo/turing-logo-dark.svg">
<img src="https://turinglang.org/assets/logo/turing-logo-light.svg" alt="Turing.jl logo" width="300">
</picture>
</p>
<p align="center"><i>Bayesian inference with probabilistic programming</i></p>
<p align="center">
<a href="https://turinglang.org/"><img src="https://img.shields.io/badge/docs-tutorials-blue.svg" alt="Tutorials" /></a>
<a href="https://turinglang.org/Turing.jl/stable"><img src="https://img.shields.io/badge/docs-API-blue.svg" alt="API docs" /></a>
@@ -9,7 +13,7 @@
<a href="https://github.com/SciML/ColPrac"><img src="https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet" alt="ColPrac: Contributor's Guide on Collaborative Practices for Community Packages" /></a>
</p>

## 🚀 Get started
## Get started

Install Julia (see [the official Julia website](https://julialang.org/install/); you will need at least Julia 1.10 for the latest version of Turing.jl).
Then, launch a Julia REPL and run:
@@ -23,22 +23,27 @@
You can define models using the `@model` macro, and then perform Markov chain Monte Carlo…
```julia
julia> using Turing

julia> @model function my_first_model(data)
mean ~ Normal(0, 1)
sd ~ truncated(Cauchy(0, 3); lower=0)
data ~ Normal(mean, sd)
julia> @model function linear_regression(x)
# Priors
α ~ Normal(0, 1)
β ~ Normal(0, 1)
σ² ~ truncated(Cauchy(0, 3); lower=0)

# Likelihood
μ = α .+ β .* x
y ~ MvNormal(μ, σ² * I)
end

julia> model = my_first_model(randn())
julia> x, y = rand(10), rand(10)

julia> chain = sample(model, NUTS(), 1000)
julia> posterior = linear_regression(x) | (; y = y)

julia> chain = sample(posterior, NUTS(), 1000)
```

You can find the main TuringLang documentation at [**https://turinglang.org**](https://turinglang.org), which contains general information about Turing.jl's features, as well as a variety of tutorials with examples of Turing.jl models.

API documentation for Turing.jl is specifically available at [**https://turinglang.org/Turing.jl/stable**](https://turinglang.org/Turing.jl/stable/).

## 🛠️ Contributing
## Contributing

### Issues

@@ -55,20 +66,20 @@
Breaking releases (minor version) should target the `breaking` branch.

If you have not received any feedback on an issue or PR for a while, please feel free to ping `@TuringLang/maintainers` in a comment.

## 💬 Other channels
## Other channels

The Turing.jl userbase tends to be most active on the [`#turing` channel of Julia Slack](https://julialang.slack.com/archives/CCYDC34A0).
If you do not have an invitation to Julia's Slack, you can get one from [the official Julia website](https://julialang.org/slack/).

There are also often threads on [Julia Discourse](https://discourse.julialang.org) (you can search using, e.g., [the `turing` tag](https://discourse.julialang.org/tag/turing)).

## 🔄 What's changed recently?
## What's changed recently?

We publish a fortnightly newsletter summarising recent updates in the TuringLang ecosystem, which you can view on [our website](https://turinglang.org/news/), [GitHub](https://github.com/TuringLang/Turing.jl/issues/2498), or [Julia Slack](https://julialang.slack.com/archives/CCYDC34A0).

For Turing.jl specifically, you can see a full changelog in [`HISTORY.md`](https://github.com/TuringLang/Turing.jl/blob/main/HISTORY.md) or [our GitHub releases](https://github.com/TuringLang/Turing.jl/releases).

## 🧩 Where does Turing.jl sit in the TuringLang ecosystem?
## Where does Turing.jl sit in the TuringLang ecosystem?

Turing.jl is the main entry point for users, and seeks to provide a unified, convenient interface to all of the functionality in the TuringLang (and broader Julia) ecosystem.

@@ -125,5 +136,3 @@
month = feb,
```

</details>

You can see the full list of publications that have cited Turing.jl on [Google Scholar](https://scholar.google.com/scholar?cites=11803241473159708991).
6 changes: 6 additions & 0 deletions docs/src/api.md
[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Turing.jl/docs/src/api.md

Lines 111 to 122 in eaac4c3

| Exported symbol | Documentation | Description |
|:---------------------- |:------------------------------------------------- |:---------------------------------------------------------------------------------------- |
| `vi` | [`Turing.vi`](@ref) | Perform variational inference |
| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family |
| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family |
| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family |
| `KLMinRepGradDescent` | [`Turing.Variational.KLMinRepGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
| `KLMinRepGradProxDescent` | [`Turing.Variational.KLMinRepGradProxDescent`](@ref) | KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
| `KLMinScoreGradDescent` | [`Turing.Variational.KLMinScoreGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the score gradient |
| `KLMinWassFwdBwd` | [`Turing.Variational.KLMinWassFwdBwd`](@ref) | KL divergence minimization via Wasserstein proximal gradient descent |
| `KLMinNaturalGradDescent` | [`Turing.Variational.KLMinNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent |
| `KLMinSqrtNaturalGradDescent` | [`Turing.Variational.KLMinSqrtNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent in the square-root parameterization |

@@ -114,6 +114,12 @@
See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for…
| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family |
| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family |
| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family |
| `KLMinRepGradDescent` | [`Turing.Variational.KLMinRepGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
| `KLMinRepGradProxDescent` | [`Turing.Variational.KLMinRepGradProxDescent`](@ref) | KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
| `KLMinScoreGradDescent` | [`Turing.Variational.KLMinScoreGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the score gradient |
| `KLMinWassFwdBwd` | [`Turing.Variational.KLMinWassFwdBwd`](@ref) | KL divergence minimization via Wasserstein proximal gradient descent |
| `KLMinNaturalGradDescent` | [`Turing.Variational.KLMinNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent |
| `KLMinSqrtNaturalGradDescent` | [`Turing.Variational.KLMinSqrtNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent in the square-root parameterization |

### Automatic differentiation types

7 changes: 6 additions & 1 deletion src/Turing.jl
@@ -116,10 +116,15 @@
export
externalsampler,
# Variational inference - AdvancedVI
vi,
ADVI,
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent,
KLMinNaturalGradDescent,
KLMinSqrtNaturalGradDescent,
KLMinWassFwdBwd,
# ADTypes
AutoForwardDiff,
AutoReverseDiff,
130 changes: 78 additions & 52 deletions src/variational/VariationalInference.jl
@@ -1,21 +1,41 @@

module Variational

using DynamicPPL
using AdvancedVI:
AdvancedVI,
KLMinRepGradDescent,
KLMinRepGradProxDescent,
KLMinScoreGradDescent,
KLMinWassFwdBwd,
KLMinNaturalGradDescent,
KLMinSqrtNaturalGradDescent

using ADTypes
using Bijectors: Bijectors
using Distributions
using DynamicPPL: DynamicPPL
using LinearAlgebra
using LogDensityProblems
using LogDensityProblems: LogDensityProblems
using Random

import ..Turing: DEFAULT_ADTYPE, PROGRESS

import AdvancedVI
import Bijectors

export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian

include("deprecated.jl")
using ..Turing: DEFAULT_ADTYPE, PROGRESS

export vi,
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent,
KLMinWassFwdBwd,
KLMinNaturalGradDescent,
KLMinSqrtNaturalGradDescent

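# Trait marking algorithms that optimize in the unconstrained (transformed)
# space; `vi` uses this to decide whether to unwrap the initial
# `Bijectors.TransformedDistribution` before optimization and re-wrap the
# result afterwards.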
requires_unconstrained_space(::AdvancedVI.AbstractVariationalAlgorithm) = false
requires_unconstrained_space(::AdvancedVI.KLMinRepGradProxDescent) = true
requires_unconstrained_space(::AdvancedVI.KLMinRepGradDescent) = true
requires_unconstrained_space(::AdvancedVI.KLMinWassFwdBwd) = true
requires_unconstrained_space(::AdvancedVI.KLMinNaturalGradDescent) = true
requires_unconstrained_space(::AdvancedVI.KLMinSqrtNaturalGradDescent) = true

"""
q_initialize_scale(
@@ -62,7 +82,7 @@
function q_initialize_scale(
num_max_trials::Int=10,
reduce_factor::Real=one(eltype(scale)) / 2,
)
prob = LogDensityFunction(model)
prob = DynamicPPL.LogDensityFunction(model)
ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
varinfo = DynamicPPL.VarInfo(model)

@@ -248,76 +268,82 @@
end
"""
vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
max_iter::Int;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(
adtype; n_samples=10
),
show_progress::Bool = Turing.PROGRESS[],
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
kwargs...
)

Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`.
Approximate the target `model` with the variational inference algorithm `algorithm`, starting from the initial variational approximation `q`.
This is a thin wrapper around `AdvancedVI.optimize`.

If the chosen variational inference algorithm operates in an unconstrained space, then the provided initial variational approximation `q` must be a `Bijectors.TransformedDistribution` of an unconstrained distribution.
For example, the initializations supplied by `q_meanfield_gaussian`, `q_fullrank_gaussian`, and `q_locationscale` satisfy this requirement.

The default `algorithm`, `KLMinRepGradProxDescent` ([relevant docs](https://turinglang.org/AdvancedVI.jl/dev/klminrepgradproxdescent/)), assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
For other variational families, refer to the documentation of `AdvancedVI` to determine the best algorithm and other options.

# Arguments
- `model`: The target `DynamicPPL.Model`.
- `q`: The initial variational approximation.
- `n_iterations`: Number of optimization steps.
- `max_iter`: Maximum number of steps.

# Keyword Arguments
- `objective`: Variational objective to be optimized.
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend to differentiate the variational objective.
- `algorithm`: Variational inference algorithm.
- `show_progress`: Whether to show the progress bar.
- `optimizer`: Optimization algorithm.
- `averager`: Parameter averaging strategy.
- `operator`: Operator applied after each optimization step.
- `adtype`: Automatic differentiation backend.

See the docs of `AdvancedVI.optimize` for additional keyword arguments.

# Returns
- `q`: Variational distribution formed by the last iterate of the optimization run.
- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`.
- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`.
- `info`: Information generated during the optimization run.
- `q`: Output variational distribution of `algorithm`.
- `info`: Information generated while executing `algorithm`.
- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`.
"""
function vi(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective=AdvancedVI.RepGradELBO(
10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
max_iter::Int,
args...;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm::AdvancedVI.AbstractVariationalAlgorithm=KLMinRepGradProxDescent(
adtype; n_samples=10
),
unconstrained::Bool=requires_unconstrained_space(algorithm),
show_progress::Bool=PROGRESS[],
optimizer=AdvancedVI.DoWG(),
averager=AdvancedVI.PolynomialAveraging(),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
kwargs...,
)
return AdvancedVI.optimize(
rng,
LogDensityFunction(model),
objective,
q,
n_iterations;
show_progress=show_progress,
adtype,
optimizer,
averager,
operator,
kwargs...,
prob, q, trans = if unconstrained
@assert q isa Bijectors.TransformedDistribution "The algorithm $(algorithm) operates in an unconstrained space. Therefore, the initial variational approximation is expected to be a Bijectors.TransformedDistribution of an unconstrained distribution."
vi = DynamicPPL.ldf_default_varinfo(model, DynamicPPL.getlogjoint_internal)
vi = DynamicPPL.set_transformed!!(vi, true)
prob = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype
)
prob, q.dist, q.transform
else
prob = DynamicPPL.LogDensityFunction(model; adtype)
prob, q, nothing
end
q, info, state = AdvancedVI.optimize(
rng, algorithm, max_iter, prob, q, args...; show_progress=show_progress, kwargs...
)
q = if unconstrained
Bijectors.TransformedDistribution(q, trans)
else
q
end
return q, info, state
end

function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...)
return vi(Random.default_rng(), model, q, n_iterations; kwargs...)
function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...)
return vi(Random.default_rng(), model, q, max_iter; kwargs...)
end

end