Skip to content

Commit 01352d4

Browse files
authored
Add broken test for Reinforce distribution rrule with variance reduction (#17)
1 parent d18dd88 commit 01352d4

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

test/expectation.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,11 @@ end
123123
r_split...) = mean(empirical_distribution(r, θ...))
124124
@test r(μ, σ) == r_split(μ, σ)
125125
@test gradient(r, μ, σ) == gradient(r_split, μ, σ)
126+
127+
r = Reinforce(
128+
exp, Normal; nb_samples=100, variance_reduction=true, rng=StableRNG(seed), seed=0
129+
)
130+
r_split...) = mean(empirical_distribution(r, θ...))
131+
@test r(μ, σ) == r_split(μ, σ)
132+
@test_broken gradient(r, μ, σ) == gradient(r_split, μ, σ)
126133
end

0 commit comments

Comments
 (0)