@@ -56,14 +56,13 @@ function estimate_energy_with_samples(prob, samples)
56
56
end
57
57
58
58
"""
59
- reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
59
+ reparam_with_entropy(rng, q, n_samples, ent_est)
60
60
61
61
Draw `n_samples` from `q` and compute its entropy.
62
62
63
63
# Arguments
64
64
- `rng::Random.AbstractRNG`: Random number generator.
65
65
- `q`: Variational approximation.
66
- - `q_stop`: `q` but with its gradient stopped.
67
66
- `n_samples::Int`: Number of Monte Carlo samples
68
67
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
69
68
@@ -72,7 +71,11 @@ Draw `n_samples` from `q` and compute its entropy.
72
71
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
73
72
"""
74
73
function reparam_with_entropy (
75
- rng:: Random.AbstractRNG , q, q_stop, n_samples:: Int , ent_est:: AbstractEntropyEstimator
74
+ rng :: Random.AbstractRNG ,
75
+ q,
76
+ q_stop,
77
+ n_samples:: Int ,
78
+ ent_est :: AbstractEntropyEstimator
76
79
)
77
80
samples = rand (rng, q, n_samples)
78
81
entropy = estimate_entropy_maybe_stl (ent_est, samples, q, q_stop)
94
97
estimate_objective (obj:: RepGradELBO , q, prob; n_samples:: Int = obj. n_samples) =
95
98
estimate_objective (Random. default_rng (), obj, q, prob; n_samples)
96
99
100
+ function estimate_repgradelbo_ad_forward (params′, aux)
101
+ @unpack rng, obj, problem, restructure, q_stop = aux
102
+ q = restructure (params′)
103
+ samples, entropy = reparam_with_entropy (rng, q, q_stop, obj. n_samples, obj. entropy)
104
+ energy = estimate_energy_with_samples (problem, samples)
105
+ elbo = energy + entropy
106
+ - elbo
107
+ end
108
+
97
109
function estimate_gradient! (
98
110
rng :: Random.AbstractRNG ,
99
111
obj :: RepGradELBO ,
100
112
adtype:: ADTypes.AbstractADType ,
101
113
out :: DiffResults.MutableDiffResult ,
102
114
prob,
103
- λ ,
115
+ params ,
104
116
restructure,
105
117
state,
106
118
)
107
- q_stop = restructure (λ)
108
- function f (λ′)
109
- q = restructure (λ′)
110
- samples, entropy = reparam_with_entropy (rng, q, q_stop, obj. n_samples, obj. entropy)
111
- energy = estimate_energy_with_samples (prob, samples)
112
- elbo = energy + entropy
113
- - elbo
114
- end
115
- value_and_gradient! (adtype, f, λ, out)
116
-
119
+ q_stop = restructure (params)
120
+ aux = (rng= rng, obj= obj, problem= prob, restructure= restructure, q_stop= q_stop)
121
+ value_and_gradient! (
122
+ adtype, estimate_repgradelbo_ad_forward, params, aux, out
123
+ )
117
124
nelbo = DiffResults. value (out)
118
125
stat = (elbo= - nelbo,)
119
-
120
126
out, nothing , stat
121
127
end
0 commit comments