
Commit 112c807

Improvements (#15)
* Improvements
* Better math
* Docstring
* Probability distribution pullback
1 parent 87382ee commit 112c807

File tree

9 files changed: +360 additions, -266 deletions

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
ChainRulesCore = "1.23"
+Compat = "3,4"
DensityInterface = "0.4"
Distributions = "0.25"
DocStringExtensions = "0.9"

docs/src/background.md

Lines changed: 125 additions & 61 deletions
@@ -1,102 +1,166 @@
# Background

-Consider a function ``f: \mathbb{R}^n \to \mathbb{R}^m`` and a parametric probability distribution ``p(\theta)`` on the input space ``\mathbb{R}^n``.
-Given a random variable ``X \sim p(\theta)``, we want to differentiate the following expectation with respect to ``\theta``:
+Most of the math below is taken from [mohamedMonteCarloGradient2020](@citet).
+
+Consider a function $f: \mathbb{R}^n \to \mathbb{R}^m$, a parameter $\theta \in \mathbb{R}^d$, and a parametric probability distribution $p(\theta)$ on the input space.
+Given a random variable $X \sim p(\theta)$, we want to differentiate the expectation of $Y = f(X)$ with respect to $\theta$:
+
+$$
+E(\theta) = \mathbb{E}[f(X)] = \int f(x) ~ p(x | \theta) ~ \mathrm{d}x
+$$
+
+Usually this is approximated with Monte-Carlo sampling: let $x_1, \dots, x_S \sim p(\theta)$ be i.i.d.; then we have the estimator
+
+$$
+E(\theta) \simeq \frac{1}{S} \sum_{s=1}^S f(x_s)
+$$
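In code, this plain Monte-Carlo estimator is just a sample mean. A minimal Julia sketch (a toy example with an arbitrary Normal distribution and function, not the package's API):

```julia
using Distributions: Normal
using Statistics: mean

f(x) = x^2                    # toy function whose expectation we estimate
μ, σ = 0.0, 1.0               # parameters θ of the sampling distribution
S = 10_000                    # number of Monte-Carlo samples

xs = rand(Normal(μ, σ), S)    # i.i.d. samples x₁, …, x_S ∼ p(θ)
E_hat = mean(f, xs)           # (1/S) ∑ f(xₛ) ≈ 𝔼[X²] = μ² + σ² = 1
```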

-```math
-F(\theta) = \mathbb{E}_{p(\theta)}[f(X)]
-```
+## Autodiff

-Since ``F`` is a vector-to-vector function, the key quantity we want to compute is its Jacobian matrix ``\partial F(\theta) \in \mathbb{R}^{m \times n}``.
-However, to implement automatic differentiation, we only need vector-Jacobian products (VJPs) ``\partial F(\theta)^\top v`` with ``v \in \mathbb{R}^m``, also called pullbacks.
+Since $E$ is a vector-to-vector function, the key quantity we want to compute is its Jacobian matrix $\partial E(\theta) \in \mathbb{R}^{m \times d}$:
+
+$$
+\partial E(\theta) = \int y ~ \nabla_\theta q(y | \theta)^\top ~ \mathrm{d}y = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~ \mathrm{d}x
+$$
+
+However, to implement automatic differentiation, we only need the vector-Jacobian product (VJP) $\partial E(\theta)^\top \bar{y}$ with an output cotangent $\bar{y} \in \mathbb{R}^m$.
See the book by [blondelElementsDifferentiableProgramming2024](@citet) to know more.

-Most of the math below is taken from [mohamedMonteCarloGradient2020](@citet).
+Our goal is to rephrase this VJP as an expectation, so that we may approximate it with Monte-Carlo sampling as well.

## REINFORCE

-### Principle
+Implemented by [`Reinforce`](@ref).

-The REINFORCE estimator is derived with the help of the identity ``\nabla \log u = \nabla u / u``:
+### Score function

-```math
+The REINFORCE estimator is derived with the help of the identity $\nabla \log u = \nabla u / u$:
+
+$$
\begin{aligned}
-F(\theta + \varepsilon)
-& = \int f(x) ~ p(x, \theta + \varepsilon) ~ \mathrm{d}x \\
-& \approx \int f(x) ~ \left(p(x, \theta) + \nabla_\theta p(x, \theta)^\top \varepsilon\right) ~ \mathrm{d}x \\
-& = \int f(x) ~ \left(p(x, \theta) + p(x, \theta) \nabla_\theta \log p(x, \theta)^\top \varepsilon\right) ~ \mathrm{d}x \\
-& = F(\theta) + \left(\int f(x) ~ p(x, \theta) \nabla_\theta \log p(x, \theta)^\top ~ \mathrm{d}x\right) \varepsilon \\
-& = F(\theta) + \mathbb{E}_{p(\theta)} \left[f(X) \nabla_\theta \log p(X, \theta)^\top\right] ~ \varepsilon \\
+\partial E(\theta)
+& = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~ \mathrm{d}x \\
+& = \int f(x) ~ p(x | \theta) \nabla_\theta \log p(x | \theta)^\top ~ \mathrm{d}x \\
+& = \mathbb{E} \left[f(X) \nabla_\theta \log p(X | \theta)^\top\right] \\
\end{aligned}
-```
+$$

-We thus identify the Jacobian matrix:
+And the VJP:

-```math
-\partial F(\theta) = \mathbb{E}_{p(\theta)} \left[f(X) \nabla_\theta \log p(X, \theta)^\top\right]
-```
+$$
+\partial E(\theta)^\top \bar{y} = \mathbb{E} \left[f(X)^\top \bar{y} ~ \nabla_\theta \log p(X | \theta)\right]
+$$

-And the vector-Jacobian product:
+Our Monte-Carlo approximation will therefore be:

-```math
-\partial F(\theta)^\top v = \mathbb{E}_{p(\theta)} \left[(f(X)^\top v) \nabla_\theta \log p(X, \theta)\right]
-```
+$$
+\partial E(\theta)^\top \bar{y} \simeq \frac{1}{S} \sum_{s=1}^S f(x_s)^\top \bar{y} ~ \nabla_\theta \log p(x_s | \theta)
+$$
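As a concrete illustration of this estimator, here is a toy Julia sketch for a univariate Normal with fixed standard deviation, where the score $\nabla_\mu \log p(x \mid \mu) = (x - \mu)/\sigma^2$ is known in closed form (an illustration only, not the package's `Reinforce` implementation):

```julia
using Distributions: Normal
using Statistics: mean

f(x) = sin(x)                               # toy scalar function
μ, σ, ybar = 1.0, 0.5, 1.0                  # θ = μ (σ fixed), output cotangent ȳ
S = 10_000

xs = rand(Normal(μ, σ), S)                  # xₛ ∼ p(θ)
score(x) = (x - μ) / σ^2                    # ∇_μ log p(x | μ) for a Normal
vjp = mean(f(x) * ybar * score(x) for x in xs)
# ≈ d/dμ 𝔼[sin(X)] = cos(μ) * exp(-σ²/2)
```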

### Variance reduction

-Since the REINFORCE estimator has high variance, it can be reduced by using a baseline [koolBuyREINFORCESamples2022](@citep).
-For $k > 1$ Monte-Carlo samples, we have
+The REINFORCE estimator has high variance, which can be reduced by subtracting a so-called baseline $b = \frac{1}{S} \sum_{s=1}^S f(x_s)$ [koolBuyREINFORCESamples2022](@citep).
+
+For $S > 1$ Monte-Carlo samples, we have

-```math
+$$
\begin{aligned}
-\partial F(\theta) &\simeq \frac{1}{k}\sum_{i=1}^k f(x_k) \nabla_\theta\log p(x_k, \theta)^\top\\
-& \simeq \frac{1}{k}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k - 1}\sum_{j\neq i} f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top\\
-& = \frac{1}{k - 1}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top
+\partial E(\theta)^\top \bar{y}
+& \simeq \frac{1}{S} \sum_{s=1}^S \left(f(x_s) - \frac{1}{S - 1}\sum_{j \neq s} f(x_j) \right)^\top \bar{y} ~ \nabla_\theta \log p(x_s | \theta)\\
+& = \frac{1}{S - 1}\sum_{s=1}^S (f(x_s) - b)^\top \bar{y} ~ \nabla_\theta \log p(x_s | \theta)
\end{aligned}
-```
-
-This gives the following vector-Jacobian product:
-
-```math
-\partial F(\theta)^\top v \simeq \frac{1}{k - 1}\sum_{i=1}^k \left(\left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right)^\top v\right) \nabla_\theta\log p(x_i, \theta)
-```
+$$
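Continuing the same toy Normal setup, a self-contained sketch of the baseline-corrected estimator (again an illustration under the same assumptions, not package code):

```julia
using Distributions: Normal
using Statistics: mean

f(x) = sin(x)
μ, σ, ybar, S = 1.0, 0.5, 1.0, 10_000
xs = rand(Normal(μ, σ), S)
score(x) = (x - μ) / σ^2                   # ∇_μ log p(x | μ)

ys = f.(xs)                                # f(xₛ)
b = mean(ys)                               # baseline b = (1/S) ∑ f(xₛ)
vjp = sum((ys[s] - b) * ybar * score(xs[s]) for s in eachindex(xs)) / (S - 1)
```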

## Reparametrization

+Implemented by [`Reparametrization`](@ref).
+
### Trick

-The reparametrization trick assumes that we can rewrite the random variable ``X \sim p(\theta)`` as ``X = g(Z, \theta)``, where ``Z \sim q`` is another random variable whose distribution does not depend on ``\theta``.
+The reparametrization trick assumes that we can rewrite the random variable $X \sim p(\theta)$ as $X = g_\theta(Z)$, where $Z \sim r$ is another random variable whose distribution $r$ does not depend on $\theta$.

-```math
-\begin{aligned}
-F(\theta + \varepsilon)
-& = \int f(g(z, \theta + \varepsilon)) ~ q(z) ~ \mathrm{d}z \\
-& \approx \int f\left(g(z, \theta) + \partial_\theta g(z, \theta) ~ \varepsilon\right) ~ q(z) ~ \mathrm{d}z \\
-& \approx F(\theta) + \int \partial f(g(z, \theta)) ~ \partial_\theta g(z, \theta) ~ \varepsilon ~ q(z) ~ \mathrm{d}z \\
-& \approx F(\theta) + \mathbb{E}_q \left[ \partial f(g(Z, \theta)) ~ \partial_\theta g(Z, \theta) \right] ~ \varepsilon \\
-\end{aligned}
-```
+The expectation is rewritten with $h_\theta = f \circ g_\theta$:

-If we denote ``h(z, \theta) = f(g(z, \theta))``, then we identify the Jacobian matrix:
+$$
+E(\theta) = \mathbb{E}\left[ f(g_\theta(Z)) \right] = \mathbb{E}\left[ h_\theta(Z) \right]
+$$

-```math
-\partial F(\theta) = \mathbb{E}_q \left[ \partial_\theta h(Z, \theta) \right]
-```
+And we can directly differentiate through the expectation:

-And the vector-Jacobian product:
+$$
+\partial E(\theta) = \mathbb{E} \left[ \partial_\theta h_\theta(Z) \right]
+$$

-```math
-\partial F(\theta)^\top v = \mathbb{E}_q \left[ \partial_\theta h(Z, \theta)^\top v \right]
-```
+This yields the VJP:
+
+$$
+\partial E(\theta)^\top \bar{y} = \mathbb{E} \left[ \partial_\theta h_\theta(Z)^\top \bar{y} \right]
+$$
+
+We can use a Monte-Carlo approximation with i.i.d. samples $z_1, \dots, z_S \sim r$:
+
+$$
+\partial E(\theta)^\top \bar{y} \simeq \frac{1}{S} \sum_{s=1}^S \partial_\theta h_\theta(z_s)^\top \bar{y}
+$$
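A minimal sketch of this pathwise estimator for the univariate Normal reparametrization $X = \mu + \sigma Z$, with the derivative of $h_\theta$ written out by hand for the toy choice $f = \sin$ (an illustration, not the package's `Reparametrization` code):

```julia
using Distributions: Normal
using Statistics: mean

# h_θ(z) = sin(μ + σz), so ∂h/∂μ = cos(μ + σz) and ∂h/∂σ = z * cos(μ + σz).
μ, σ, ybar, S = 1.0, 0.5, 1.0, 10_000

zs = rand(Normal(0, 1), S)                            # zₛ ∼ r, independent of θ
vjp_μ = mean(cos(μ + σ * z) * ybar for z in zs)       # VJP component w.r.t. μ
vjp_σ = mean(z * cos(μ + σ * z) * ybar for z in zs)   # VJP component w.r.t. σ
# vjp_μ ≈ cos(μ) * exp(-σ²/2), matching the REINFORCE estimate above.
```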

### Catalogue

The following reparametrizations are implemented:

-- Univariate Normal: ``X \sim \mathcal{N}(\mu, \sigma^2)`` is equivalent to ``X = \mu + \sigma Z`` with ``Z \sim \mathcal{N}(0, 1)``.
-- Multivariate Normal: ``X \sim \mathcal{N}(\mu, \Sigma)`` is equivalent to ``X = \mu + L Z`` with ``Z \sim \mathcal{N}(0, I)`` and ``L L^\top = \Sigma``. The matrix ``L`` can be obtained by Cholesky decomposition of ``\Sigma``.
+- Univariate Normal: $X \sim \mathcal{N}(\mu, \sigma^2)$ is equivalent to $X = \mu + \sigma Z$ with $Z \sim \mathcal{N}(0, 1)$.
+- Multivariate Normal: $X \sim \mathcal{N}(\mu, \Sigma)$ is equivalent to $X = \mu + L Z$ with $Z \sim \mathcal{N}(0, I)$ and $L L^\top = \Sigma$. The matrix $L$ can be obtained by Cholesky decomposition of $\Sigma$.
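The multivariate case from this catalogue can be checked numerically with a small sketch (arbitrary toy values for $\mu$ and $\Sigma$, not package code):

```julia
using LinearAlgebra: cholesky, Symmetric
using Statistics: mean

# X = μ + L Z with L L' = Σ and Z ∼ N(0, I).
μ = [1.0, -2.0]
Σ = [2.0 0.5; 0.5 1.0]
L = cholesky(Symmetric(Σ)).L

S = 10_000
zs = [randn(2) for _ in 1:S]        # Z ∼ N(0, I)
xs = [μ + L * z for z in zs]        # X ∼ N(μ, Σ)
mean(xs)                            # ≈ μ (and the empirical covariance ≈ Σ)
```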
+
+## Probability gradients
+
+In addition to the expectation, we may also want gradients for individual output densities $q(y | \theta) = \mathbb{P}(f(X) = y)$.
+
+### REINFORCE probability gradients
+
+The REINFORCE technique can be applied in a similar way:
+
+$$
+q(y | \theta) = \mathbb{E}[\mathbf{1}\{f(X) = y\}] = \int \mathbf{1} \{f(x) = y\} ~ p(x | \theta) ~ \mathrm{d}x
+$$
+
+Differentiating through the integral,
+
+$$
+\begin{aligned}
+\nabla_\theta q(y | \theta)
+& = \int \mathbf{1} \{f(x) = y\} ~ \nabla_\theta p(x | \theta) ~ \mathrm{d}x \\
+& = \mathbb{E} [\mathbf{1} \{f(X) = y\} ~ \nabla_\theta \log p(X | \theta)]
+\end{aligned}
+$$
+
+The Monte-Carlo approximation for this is
+
+$$
+\nabla_\theta q(y | \theta) \simeq \frac{1}{S} \sum_{s=1}^S \mathbf{1} \{f(x_s) = y\} ~ \nabla_\theta \log p(x_s | \theta)
+$$
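For a toy discrete-output case this estimator is easy to check (assuming $f(x) = \mathbf{1}\{x > 0\}$ and $X \sim \mathcal{N}(\mu, \sigma^2)$ with $\theta = \mu$, so that $q(1 \mid \theta) = \Phi(\mu/\sigma)$ has a known gradient; an illustration, not package code):

```julia
using Distributions: Normal, pdf
using Statistics: mean

f(x) = x > 0                               # discrete output y ∈ {false, true}
μ, σ, S = 0.3, 1.0, 100_000

xs = rand(Normal(μ, σ), S)
score(x) = (x - μ) / σ^2                   # ∇_μ log p(x | μ)
grad_hat = mean((f(x) == true) * score(x) for x in xs)   # REINFORCE estimate of ∇_μ q(1 | μ)
grad_exact = pdf(Normal(), μ / σ) / σ      # φ(μ/σ)/σ, for comparison
```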
+
+### Reparametrization probability gradients
+
+To leverage reparametrization, we perform a change of variables:
+
+$$
+q(y | \theta) = \mathbb{E}[\mathbf{1}\{h_\theta(Z) = y\}] = \int \mathbf{1} \{h_\theta(z) = y\} ~ r(z) ~ \mathrm{d}z
+$$
+
+Assuming that $h_\theta$ is invertible, we take $z = h_\theta^{-1}(u)$ and
+
+$$
+\mathrm{d}z = |\partial h_{\theta}^{-1}(u)| ~ \mathrm{d}u
+$$
+
+so that
+
+$$
+q(y | \theta) = \int \mathbf{1} \{u = y\} ~ r(h_\theta^{-1}(u)) ~ |\partial h_{\theta}^{-1}(u)| ~ \mathrm{d}u
+$$
+
+We can now differentiate, but it gets tedious.
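For reference, carrying out that differentiation is just a product rule on the change-of-variables density (spelled out here only to show where the tedium comes from):

$$
\begin{aligned}
q(y | \theta) & = r(h_\theta^{-1}(y)) ~ |\partial h_\theta^{-1}(y)| \\
\nabla_\theta q(y | \theta) & = |\partial h_\theta^{-1}(y)| ~ \nabla_\theta \left[ r(h_\theta^{-1}(y)) \right] + r(h_\theta^{-1}(y)) ~ \nabla_\theta |\partial h_\theta^{-1}(y)|
\end{aligned}
$$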

## Bibliography

-```@bibliography
-```
+$$@bibliography
+$$

src/DifferentiableExpectations.jl

Lines changed: 5 additions & 1 deletion
@@ -19,6 +19,7 @@ using ChainRulesCore:
    rrule,
    rrule_via_ad,
    unthunk
+using Compat: @compat
using DensityInterface: logdensityof
using Distributions: Distribution, MvNormal, Normal
using DocStringExtensions
@@ -29,14 +30,17 @@ using Statistics: Statistics, cov, mean, std
using StatsBase: StatsBase

include("utils.jl")
+include("distribution.jl")
include("abstract.jl")
include("reinforce.jl")
include("reparametrization.jl")
-include("distribution.jl")

export DifferentiableExpectation
export Reinforce
export Reparametrization
export FixedAtomsProbabilityDistribution
+export empirical_distribution
+
+@compat public atoms, weights

end # module DifferentiableExpectations

src/abstract.jl

Lines changed: 23 additions & 35 deletions
@@ -1,7 +1,7 @@
"""
-    DifferentiableExpectation{threaded}
+    DifferentiableExpectation{t}

-Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages.
+Abstract supertype for differentiable parametric expectations `E : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages.

# Subtypes

@@ -10,9 +10,9 @@ Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(

# Calling behavior

-    (F::DifferentiableExpectation)(θ...; kwargs...)
+    (E::DifferentiableExpectation)(θ...; kwargs...)

-Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
+Return a Monte-Carlo average `(1/S) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.

# Type parameters

@@ -32,49 +32,37 @@ The resulting object `dist` needs to satisfy:
- the [Random API](https://docs.julialang.org/en/v1/stdlib/Random/#Hooking-into-the-Random-API) for sampling with `rand(rng, dist)`
- the [DensityInterface.jl API](https://github.com/JuliaMath/DensityInterface.jl) for loglikelihoods with `logdensityof(dist, x)`
"""
-abstract type DifferentiableExpectation{threaded} end
+abstract type DifferentiableExpectation{t} end
+
+is_threaded(::DifferentiableExpectation{t}) where {t} = Val(t)

"""
-    presamples(F::DifferentiableExpectation, θ...)
+    empirical_predistribution(E::DifferentiableExpectation, θ...)

-Return a vector `[x₁, ..., xₛ]` or matrix `[x₁ ... xₛ]` where the `xᵢ ∼ p(θ)` are iid samples.
+Return a uniform [`FixedAtomsProbabilityDistribution`](@ref) over `{x₁, ..., xₛ}`, where the `xᵢ ∼ p(θ)` are iid samples.
"""
-function presamples(F::DifferentiableExpectation, θ...)
-    (; dist_constructor, rng, nb_samples, seed) = F
+function empirical_predistribution(E::DifferentiableExpectation, θ...)
+    (; dist_constructor, rng, nb_samples, seed) = E
    dist = dist_constructor(θ...)
    isnothing(seed) || seed!(rng, seed)
    xs = maybe_eachcol(rand(rng, dist, nb_samples))
-    return xs
+    xdist = FixedAtomsProbabilityDistribution(xs; threaded=unval(is_threaded(E)))
+    return xdist
end

"""
-    samples(F::DifferentiableExpectation, θ...; kwargs...)
+    empirical_distribution(E::DifferentiableExpectation, θ...; kwargs...)

-Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples.
+Return a uniform [`FixedAtomsProbabilityDistribution`](@ref) over `{f(x₁), ..., f(xₛ)}`, where the `xᵢ ∼ p(θ)` are iid samples.
"""
-function samples(F::DifferentiableExpectation{threaded}, θ...; kwargs...) where {threaded}
-    xs = presamples(F, θ...)
-    return samples_from_presamples(F, xs; kwargs...)
-end
-
-function samples_from_presamples(
-    F::DifferentiableExpectation{threaded}, xs::AbstractVector; kwargs...
-) where {threaded}
-    (; f) = F
-    fk = FixKwargs(f, kwargs)
-    if threaded
-        return tmap(fk, xs)
-    else
-        return map(fk, xs)
-    end
+function empirical_distribution(E::DifferentiableExpectation, θ...; kwargs...)
+    xdist = empirical_predistribution(E, θ...)
+    fk = FixKwargs(E.f, kwargs)
+    ydist = map(fk, xdist)
+    return ydist
end

-function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded}
-    ys = samples(F, θ...; kwargs...)
-    y = if threaded
-        tmean(ys)
-    else
-        mean(ys)
-    end
-    return y
+function (E::DifferentiableExpectation)(θ...; kwargs...)
+    ydist = empirical_distribution(E, θ...; kwargs...)
+    return mean(ydist)
end
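To see how these pieces fit together, here is a simplified, self-contained mock of the refactored pipeline (plain functions and a toy struct, not the package's actual types or constructors): sample an empirical predistribution of inputs, push it forward through `f` with `map`, then average.

```julia
using Distributions: Normal
using Statistics: mean

struct MockExpectation{F,D}
    f::F
    dist_constructor::D
    nb_samples::Int
end

# Atoms of a uniform empirical predistribution over iid samples xᵢ ∼ p(θ).
mock_predistribution(E::MockExpectation, θ...) =
    [rand(E.dist_constructor(θ...)) for _ in 1:E.nb_samples]

# Pushforward: map f over the atoms to get the empirical output distribution.
mock_distribution(E::MockExpectation, θ...) = map(E.f, mock_predistribution(E, θ...))

# Calling the expectation returns the mean of the empirical output distribution.
(E::MockExpectation)(θ...) = mean(mock_distribution(E, θ...))

E = MockExpectation(abs2, Normal, 1_000)
E(0.0, 1.0)   # ≈ 𝔼[X²] = 1 for X ∼ N(0, 1)
```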
