Skip to content

Commit fb0d118

Browse files
BatyLeo and gdalle authored
Add variance reduction for proba dist rrule (#20)
* Support keyword arguments in method applied to FixedAtomProbabilityDistribution with function f * update docs * Typos * remove kwargs from mean --------- Co-authored-by: Guillaume Dalle <[email protected]>
1 parent 76298d3 commit fb0d118

File tree

5 files changed

+85
-55
lines changed

5 files changed

+85
-55
lines changed

docs/src/DiffExp.bib

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,50 @@
11
@misc{blondelElementsDifferentiableProgramming2024,
2-
title = {The {{Elements}} of {{Differentiable Programming}}},
3-
author = {Blondel, Mathieu and Roulet, Vincent},
4-
year = {2024},
5-
month = mar,
6-
number = {arXiv:2403.14606},
7-
eprint = {2403.14606},
8-
primaryclass = {cs},
9-
publisher = {arXiv},
10-
doi = {10.48550/arXiv.2403.14606},
11-
url = {http://arxiv.org/abs/2403.14606},
12-
urldate = {2024-03-22},
13-
abstract = {Artificial intelligence has recently experienced remarkable advances, fueled by large models, vast datasets, accelerated hardware, and, last but not least, the transformative power of differentiable programming. This new programming paradigm enables end-to-end differentiation of complex computer programs (including those with control flows and data structures), making gradient-based optimization of program parameters possible. As an emerging paradigm, differentiable programming builds upon several areas of computer science and applied mathematics, including automatic differentiation, graphical models, optimization and statistics. This book presents a comprehensive review of the fundamental concepts useful for differentiable programming. We adopt two main perspectives, that of optimization and that of probability, with clear analogies between the two. Differentiable programming is not merely the differentiation of programs, but also the thoughtful design of programs intended for differentiation. By making programs differentiable, we inherently introduce probability distributions over their execution, providing a means to quantify the uncertainty associated with program outputs.},
14-
archiveprefix = {arXiv},
2+
title = {The {{Elements}} of {{Differentiable Programming}}},
3+
author = {Blondel, Mathieu and Roulet, Vincent},
4+
year = {2024},
5+
month = mar,
6+
number = {arXiv:2403.14606},
7+
eprint = {2403.14606},
8+
primaryclass = {cs},
9+
publisher = {arXiv},
10+
doi = {10.48550/arXiv.2403.14606},
11+
url = {http://arxiv.org/abs/2403.14606},
12+
urldate = {2024-03-22},
13+
abstract = {Artificial intelligence has recently experienced remarkable advances, fueled by large models, vast datasets, accelerated hardware, and, last but not least, the transformative power of differentiable programming. This new programming paradigm enables end-to-end differentiation of complex computer programs (including those with control flows and data structures), making gradient-based optimization of program parameters possible. As an emerging paradigm, differentiable programming builds upon several areas of computer science and applied mathematics, including automatic differentiation, graphical models, optimization and statistics. This book presents a comprehensive review of the fundamental concepts useful for differentiable programming. We adopt two main perspectives, that of optimization and that of probability, with clear analogies between the two. Differentiable programming is not merely the differentiation of programs, but also the thoughtful design of programs intended for differentiation. By making programs differentiable, we inherently introduce probability distributions over their execution, providing a means to quantify the uncertainty associated with program outputs.},
14+
archiveprefix = {arXiv}
1515
}
1616
% == BibTeX quality report for blondelElementsDifferentiableProgramming2024:
1717
% ? Title looks like it was stored in title-case in Zotero
1818
1919
@article{koolBuyREINFORCESamples2022,
20-
title = {Buy 4 {{REINFORCE Samples}}, {{Get}} a {{Baseline}} for {{Free}}!},
21-
author = {Kool, Wouter and van Hoof, Herke and Welling, Max},
22-
year = {2022},
23-
month = jul,
24-
url = {https://openreview.net/forum?id=r1lgTGL5DE},
25-
urldate = {2023-04-17},
20+
title = {Buy 4 {{REINFORCE Samples}}, {{Get}} a {{Baseline}} for {{Free}}!},
21+
author = {Kool, Wouter and van Hoof, Herke and Welling, Max},
22+
year = {2022},
23+
month = jul,
24+
journal = {ICLR},
25+
url = {https://openreview.net/forum?id=r1lgTGL5DE},
26+
urldate = {2023-04-17},
2627
abstract = {REINFORCE can be used to train models in structured prediction settings to directly optimize the test-time objective. However, the common case of sampling one prediction per datapoint (input) is data-inefficient. We show that by drawing multiple samples (predictions) per datapoint, we can learn with significantly less data, as we freely obtain a REINFORCE baseline to reduce variance. Additionally we derive a REINFORCE estimator with baseline, based on sampling without replacement. Combined with a recent technique to sample sequences without replacement using Stochastic Beam Search, this improves the training procedure for a sequence model that predicts the solution to the Travelling Salesman Problem.},
27-
langid = {english},
28-
language = {en},
28+
langid = {english},
29+
language = {en}
2930
}
3031
% == BibTeX quality report for koolBuyREINFORCESamples2022:
3132
% Missing required field 'journal'
3233
% ? Title looks like it was stored in title-case in Zotero
3334
% ? unused Library catalog ("openreview.net")
3435
3536
@article{mohamedMonteCarloGradient2020,
36-
title = {Monte {{Carlo Gradient Estimation}} in {{Machine Learning}}},
37-
author = {Mohamed, Shakir and Rosca, Mihaela and Figurnov, Michael and Mnih, Andriy},
38-
year = {2020},
39-
journal = {Journal of Machine Learning Research},
40-
volume = {21},
41-
number = {132},
42-
pages = {1--62},
43-
issn = {1533-7928},
44-
url = {http://jmlr.org/papers/v21/19-346.html},
45-
urldate = {2022-10-21},
46-
abstract = {This paper is a broad and accessible survey of the methods we have at our disposal for Monte Carlo gradient estimation in machine learning and across the statistical sciences: the problem of computing the gradient of an expectation of a function with respect to parameters defining the distribution that is integrated; the problem of sensitivity analysis. In machine learning research, this gradient problem lies at the core of many learning problems, in supervised, unsupervised and reinforcement learning. We will generally seek to rewrite such gradients in a form that allows for Monte Carlo estimation, allowing them to be easily and efficiently used and analysed. We explore three strategies---the pathwise, score function, and measure-valued gradient estimators---exploring their historical development, derivation, and underlying assumptions. We describe their use in other fields, show how they are related and can be combined, and expand on their possible generalisations. Wherever Monte Carlo gradient estimators have been derived and deployed in the past, important advances have followed. A deeper and more widely-held understanding of this problem will lead to further advances, and it is these advances that we wish to support.},
37+
title = {Monte {{Carlo Gradient Estimation}} in {{Machine Learning}}},
38+
author = {Mohamed, Shakir and Rosca, Mihaela and Figurnov, Michael and Mnih, Andriy},
39+
year = {2020},
40+
journal = {Journal of Machine Learning Research},
41+
volume = {21},
42+
number = {132},
43+
pages = {1--62},
44+
issn = {1533-7928},
45+
url = {http://jmlr.org/papers/v21/19-346.html},
46+
urldate = {2022-10-21},
47+
abstract = {This paper is a broad and accessible survey of the methods we have at our disposal for Monte Carlo gradient estimation in machine learning and across the statistical sciences: the problem of computing the gradient of an expectation of a function with respect to parameters defining the distribution that is integrated; the problem of sensitivity analysis. In machine learning research, this gradient problem lies at the core of many learning problems, in supervised, unsupervised and reinforcement learning. We will generally seek to rewrite such gradients in a form that allows for Monte Carlo estimation, allowing them to be easily and efficiently used and analysed. We explore three strategies---the pathwise, score function, and measure-valued gradient estimators---exploring their historical development, derivation, and underlying assumptions. We describe their use in other fields, show how they are related and can be combined, and expand on their possible generalisations. Wherever Monte Carlo gradient estimators have been derived and deployed in the past, important advances have followed. A deeper and more widely-held understanding of this problem will lead to further advances, and it is these advances that we wish to support.}
4748
}
4849
% == BibTeX quality report for mohamedMonteCarloGradient2020:
4950
% ? Title looks like it was stored in title-case in Zotero

docs/src/background.md

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Most of the math below is taken from [mohamedMonteCarloGradient2020](@citet).
55
Consider a function $f: \mathbb{R}^n \to \mathbb{R}^m$, a parameter $\theta \in \mathbb{R}^d$ and a parametric probability distribution $p(\theta)$ on the input space.
66
Given a random variable $X \sim p(\theta)$, we want to differentiate the expectation of $Y = f(X)$ with respect to $\theta$:
77

8-
$$E(\theta) = \mathbb{E}[f(X)] = \int f(x) ~ p(x | \theta) ~\mathrm{d} x$$
8+
$$E(\theta) = \mathbb{E}[f(X)] = \int f(x) ~ p(x | \theta) ~\mathrm{d} x = \int y ~ q(y | \theta) ~\mathrm{d} y$$
99

1010
Usually this is approximated with Monte-Carlo sampling: given i.i.d. samples $x_1, \dots, x_S \sim p(\theta)$, we have the estimator
1111

@@ -15,7 +15,7 @@ $$E(\theta) \simeq \frac{1}{S} \sum_{s=1}^S f(x_s)$$
1515

1616
Since $E$ is a vector-to-vector function, the key quantity we want to compute is its Jacobian matrix $\partial E(\theta) \in \mathbb{R}^{m \times d}$:
1717

18-
$$\partial E(\theta) = \int y ~ \nabla_\theta q(y | \theta)^\top ~ \mathrm{d} y = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~\mathrm{d} x$$
18+
$$\partial E(\theta) = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~\mathrm{d} x = \int y ~ \nabla_\theta q(y | \theta)^\top ~ \mathrm{d} y$$
1919

2020
However, to implement automatic differentiation, we only need the vector-Jacobian product (VJP) $\partial E(\theta)^\top \bar{y}$ with an output cotangent $\bar{y} \in \mathbb{R}^m$.
2121
See the book by [blondelElementsDifferentiableProgramming2024](@citet) to learn more.
@@ -33,7 +33,7 @@ The REINFORCE estimator is derived with the help of the identity $\nabla \log u
3333
$$\begin{aligned}
3434
\partial E(\theta)
3535
& = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~ \mathrm{d}x \\
36-
& = \int f(x) ~ p(x | \theta) \nabla_\theta \log p(x | \theta)^\top ~ \mathrm{d}x \\
36+
& = \int f(x) ~ \nabla_\theta \log p(x | \theta)^\top p(x | \theta) ~ \mathrm{d}x \\
3737
& = \mathbb{E} \left[f(X) \nabla_\theta \log p(X | \theta)^\top\right] \\
3838
\end{aligned}$$
3939

@@ -53,7 +53,7 @@ For $S > 1$ Monte-Carlo samples, we have
5353

5454
$$\begin{aligned}
5555
\partial E(\theta)^\top \bar{y}
56-
& \simeq \frac{1}{S} \sum_{s=1}^S \left(f(x_s) - \frac{1}{S - 1}\sum_{j\neq i} f(x_j) \right)^\top \bar{y} ~ \nabla_\theta\log p(x_s | \theta)\\
56+
& \simeq \frac{1}{S} \sum_{s=1}^S \left(f(x_s) - \frac{1}{S - 1}\sum_{j\neq s} f(x_j) \right)^\top \bar{y} ~ \nabla_\theta\log p(x_s | \theta)\\
5757
& = \frac{1}{S - 1}\sum_{s=1}^S (f(x_s) - b)^\top \bar{y} ~ \nabla_\theta\log p(x_s | \theta)
5858
\end{aligned}$$
5959

@@ -90,38 +90,55 @@ The following reparametrizations are implemented:
9090

9191
## Probability gradients
9292

93-
In addition to the expectation, we may also want gradients for individual output densities $q(y | \theta) = \mathbb{P}(f(X) = y)$.
93+
In the case where $f$ is a function that takes values in a finite set $\mathcal{Y} = \{y_1, \dots, y_K\}$, we may also want to compute the Jacobian of the probability weights vector:
94+
95+
$$q : \theta \longmapsto \begin{pmatrix} q(y_1|\theta) = \mathbb{P}(f(X) = y_1|\theta) \\ \dots \\ q(y_K|\theta) = \mathbb{P}(f(X) = y_K|\theta) \end{pmatrix}$$
96+
97+
whose Jacobian is given by
98+
99+
$$\partial_\theta q(\theta) = \begin{pmatrix} \nabla_\theta q(y_1|\theta)^\top \\ \dots \\ \nabla_\theta q(y_K|\theta)^\top \end{pmatrix}$$
94100

95101
### REINFORCE probability gradients
96102

97103
The REINFORCE technique can be applied in a similar way:
98104

99-
$$q(y | \theta) = \mathbb{E}[\mathbf{1}\{f(X) = y\}] = \int \mathbf{1} \{f(x) = y\} ~ p(x | \theta) ~ \mathrm{d}x$$
105+
$$q(y_k | \theta) = \mathbb{E}[\mathbf{1}\{f(X) = y_k\}] = \int \mathbf{1} \{f(x) = y_k\} ~ p(x | \theta) ~ \mathrm{d}x$$
100106

101107
Differentiating through the integral,
102108

103109
$$\begin{aligned}
104-
\nabla_\theta q(y | \theta)
105-
& = \int \mathbf{1} \{f(x) = y\} ~ \nabla_\theta p(x | \theta) ~ \mathrm{d}x \\
106-
& = \mathbb{E} [\mathbf{1} \{f(X) = y\} ~ \nabla_\theta \log p(X | \theta)]
110+
\nabla_\theta q(y_k | \theta)
111+
& = \int \mathbf{1} \{f(x) = y_k\} ~ \nabla_\theta p(x | \theta) ~ \mathrm{d}x \\
112+
& = \mathbb{E} [\mathbf{1} \{f(X) = y_k\} ~ \nabla_\theta \log p(X | \theta)]
107113
\end{aligned}$$
108114

109115
The Monte-Carlo approximation for this is
110116

111-
$$\nabla_\theta q(y | \theta) \simeq \frac{1}{S} \sum_{s=1}^S \mathbf{1} \{f(x_s) = y\} ~ \nabla_\theta \log p(x_s | \theta)$$
117+
$$\nabla_\theta q(y_k | \theta) \simeq \frac{1}{S} \sum_{s=1}^S \mathbf{1} \{f(x_s) = y_k\} ~ \nabla_\theta \log p(x_s | \theta)$$
118+
119+
The VJP is then
112120

113-
In our implementation, we assume that the sampled $y_s$ are pairwise distinct (maybe not necessary?), and that together they form the whole support of the distribution $q$.
114-
We can thus consider the vector-to-vector mapping
121+
$$\begin{aligned}
122+
\partial_\theta q(\theta)^\top \bar{q} &= \sum_{k=1}^K \bar{q}_k \nabla_\theta q(y_k | \theta)\\
123+
&\simeq \frac{1}{S} \sum_{s=1}^S \left[\sum_{k=1}^K \bar{q}_k \mathbf{1} \{f(x_s) = y_k\}\right] ~ \nabla_\theta \log p(x_s | \theta)
124+
\end{aligned}$$
115125

116-
$$q : \theta \longmapsto \begin{pmatrix} q(y_1|\theta) \\ \dots \\ q(y_S | \theta) \end{pmatrix}$$
126+
In our implementation, the [`empirical_distribution`](@ref) method outputs an empirical [`FixedAtomsProbabilityDistribution`](@ref) with uniform weights $\frac{1}{S}$, where some $x_s$ can be repeated.
117127

118-
whose Jacobian is given by
128+
$$q : \theta \longmapsto \begin{pmatrix} q(f(x_1)|\theta) \\ \dots \\ q(f(x_S) | \theta) \end{pmatrix}$$
129+
130+
We therefore define the corresponding VJP as
131+
132+
$$\partial_\theta q(\theta)^\top \bar{q} = \frac{1}{S} \sum_{s=1}^S \bar{q}_s \nabla_\theta \log p(x_s | \theta)$$
119133

120-
$$\partial_\theta q(\theta) = \frac{1}{S} \begin{pmatrix} \nabla_\theta \log p(x_1 | \theta)^\top \\ \dots \\ \nabla_\theta \log p(x_S | \theta)^\top \end{pmatrix}$$
134+
If $\bar q$ comes from `mean`, we have $\bar q_s = f(x_s)^\top \bar y$ and we obtain the REINFORCE VJP.
121135

122-
and whose VJP is given by
136+
This VJP can be interpreted as an empirical expectation, to which we can also apply variance reduction:
137+
$$\partial_\theta q(\theta)^\top \bar q \simeq \frac{1}{S-1}\sum_{s=1}^S (\bar q_s - b') \nabla_\theta \log p(x_s|\theta)$$
138+
with $b' = \frac{1}{S}\sum_s \bar q_s$.
123139

124-
$$\partial_\theta q(\theta)^\top \bar{q} = \frac{1}{S} \sum_s \bar{q}_s \nabla_\theta \log p(x_s | \theta)$$
140+
Again, if $\bar q$ comes from `mean`, we have $\bar q_s = f(x_s)^\top \bar y$ and $b' = b^\top \bar y$. We then obtain the REINFORCE backward rule with variance reduction:
141+
$$\partial_\theta q(\theta)^\top \bar q \simeq \frac{1}{S-1}\sum_{s=1}^S (f(x_s) - b)^\top \bar y ~ \nabla_\theta \log p(x_s|\theta)$$
125142

126143
### Reparametrization probability gradients
127144

src/reinforce.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,12 @@ function ChainRulesCore.rrule(
124124
end
125125

126126
function ChainRulesCore.rrule(
127-
rc::RuleConfig, ::typeof(empirical_distribution), E::Reinforce, θ...; kwargs...
128-
)
127+
rc::RuleConfig,
128+
::typeof(empirical_distribution),
129+
E::Reinforce{t,variance_reduction},
130+
θ...;
131+
kwargs...,
132+
) where {t,variance_reduction}
129133
project_θ = ProjectTo(θ)
130134

131135
(; f, nb_samples) = E
@@ -137,12 +141,22 @@ function ChainRulesCore.rrule(
137141
_dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, E, x, θ...)
138142
gs = mymap(is_threaded(E), _dist_logdensity_grad_partial, xs)
139143

144+
adjusted_nb_samples = nb_samples - (variance_reduction && nb_samples > 1)
145+
140146
function pullback_Reinforce_probadist(Δdist_thunked)
141147
Δdist = unthunk(Δdist_thunked)
142148
Δps = Δdist.weights
149+
Δps_mean = mean(Δps)
150+
Δps_baseline = if (variance_reduction && nb_samples > 1)
151+
Δps .- Δps_mean
152+
else
153+
Δps
154+
end
143155
ΔE = @not_implemented("The fields of the `Reinforce` object are constant.")
144156
_single_sample_pullback(gᵢ, Δpᵢ) = gᵢ .* Δpᵢ
145-
Δθ = mymapreduce(is_threaded(E), _single_sample_pullback, .+, gs, Δps) ./ nb_samples
157+
Δθ =
158+
mymapreduce(is_threaded(E), _single_sample_pullback, .+, gs, Δps_baseline) ./
159+
adjusted_nb_samples
146160
Δθ_proj = project_θ(Δθ)
147161
return (NoTangent(), ΔE, Δθ_proj...)
148162
end

test/distribution.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ rng = StableRNG(63)
1717
for threaded in (false, true)
1818
dist = FixedAtomsProbabilityDistribution([2, 3], [0.4, 0.6]; threaded)
1919

20-
string(dist)
21-
2220
@test length(dist) == 2
2321

2422
@test mean(dist) ≈ 2.6

test/expectation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,5 @@ end
129129
)
130130
r_split(θ...) = mean(empirical_distribution(r, θ...))
131131
@test r(μ, σ) == r_split(μ, σ)
132-
@test_broken gradient(r, μ, σ) == gradient(r_split, μ, σ)
132+
@test all(isapprox.(gradient(r, μ, σ), gradient(r_split, μ, σ); atol=1e-10))
133133
end

0 commit comments

Comments
 (0)