Skip to content

Commit 936b45a

Browse files
authored
Remove non-Turing specific AD tests (#1487)
* Remove non-Turing specific AD tests * Remove unused method
1 parent abda5f7 commit 936b45a

File tree

1 file changed

+0
-185
lines changed

1 file changed

+0
-185
lines changed

test/core/ad.jl

Lines changed: 0 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ const FDM = FiniteDifferences
1212
dir = splitdir(splitdir(pathof(Turing))[1])[1]
1313
include(dir*"/test/test_utils/AllUtils.jl")
1414

15-
_to_cov(B) = B * B' + Matrix(I, size(B)...)
1615
@testset "ad.jl" begin
1716
@turing_testset "adr" begin
1817
ad_test_f = gdemo_default
@@ -49,20 +48,6 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
4948
∇E2 = gradient_logp(ZygoteAD(), x, vi, ad_test_f)[2]
5049
@test sort(∇E2) grad_FWAD atol=1e-9
5150
end
52-
@turing_testset "passing duals to distributions" begin
53-
float1 = 1.1
54-
float2 = 2.3
55-
d1 = Dual(1.1)
56-
d2 = Dual(2.3)
57-
58-
@test logpdf(Normal(0, 1), d1).value logpdf(Normal(0, 1), float1) atol=0.001
59-
@test logpdf(Gamma(2, 3), d2).value logpdf(Gamma(2, 3), float2) atol=0.001
60-
@test logpdf(Beta(2, 3), (d2 - d1) / 2).value logpdf(Beta(2, 3), (float2 - float1) / 2) atol=0.001
61-
62-
@test pdf(Normal(0, 1), d1).value pdf(Normal(0, 1), float1) atol=0.001
63-
@test pdf(Gamma(2, 3), d2).value pdf(Gamma(2, 3), float2) atol=0.001
64-
@test pdf(Beta(2, 3), (d2 - d1) / 2).value pdf(Beta(2, 3), (float2 - float1) / 2) atol=0.001
65-
end
6651
@turing_testset "general AD tests" begin
6752
# Tests gdemo gradient.
6853
function logp1(x::Vector)
@@ -93,176 +78,6 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
9378

9479
test_model_ad(wishart_ad(), logp3, [:v])
9580
end
96-
@turing_testset "Tracker, Zygote and ReverseDiff + logdet" begin
97-
rng, N = MersenneTwister(123456), 7
98-
ȳ, B = randn(rng), randn(rng, N, N)
99-
test_reverse_mode_ad(B->logdet(cholesky(_to_cov(B))), ȳ, B; rtol=1e-8, atol=1e-6)
100-
end
101-
@turing_testset "Tracker & Zygote + fill" begin
102-
rng = MersenneTwister(123456)
103-
test_reverse_mode_ad(x->fill(x, 7), randn(rng, 7), randn(rng))
104-
test_reverse_mode_ad(x->fill(x, 7, 11), randn(rng, 7, 11), randn(rng))
105-
test_reverse_mode_ad(x->fill(x, 7, 11, 13), rand(rng, 7, 11, 13), randn(rng))
106-
end
107-
@turing_testset "Tracker, Zygote and ReverseDiff + MvNormal" begin
108-
rng, N = MersenneTwister(123456), 11
109-
B = randn(rng, N, N)
110-
m, A = randn(rng, N), B' * B + I
111-
112-
# Generate from the TuringDenseMvNormal
113-
d, back = Tracker.forward(TuringDenseMvNormal, m, A)
114-
x = Tracker.data(rand(d))
115-
116-
# Check that the logpdf agrees with MvNormal.
117-
d_ref = MvNormal(m, PDMat(A))
118-
@test logpdf(d, x) logpdf(d_ref, x)
119-
120-
test_reverse_mode_ad((m, B, x)->logpdf(MvNormal(m, _to_cov(B)), x), randn(rng), m, B, x)
121-
end
122-
@turing_testset "Tracker, Zygote and ReverseDiff + Diagonal Normal" begin
123-
rng, N = MersenneTwister(123456), 11
124-
m, σ = randn(rng, N), exp.(0.1 .* randn(rng, N)) .+ 1
125-
126-
d = TuringDiagMvNormal(m, σ)
127-
x = rand(d)
128-
129-
# Check that the logpdf agrees with MvNormal.
130-
d_ref = MvNormal(m, σ)
131-
@test logpdf(d, x) logpdf(d_ref, x)
132-
133-
test_reverse_mode_ad((m, σ, x)->logpdf(MvNormal(m, σ), x), randn(rng), m, σ, x)
134-
end
135-
@turing_testset "Tracker, Zygote and ReverseDiff + MvNormal Interface" begin
136-
# Note that we only test methods where the `MvNormal` ctor actually constructs
137-
# a TuringDenseMvNormal.
138-
139-
rng, N = MersenneTwister(123456), 7
140-
m, b, B, x = randn(rng, N), randn(rng, N), randn(rng, N, N), randn(rng, N)
141-
= randn(rng)
142-
143-
# zero mean, dense covariance
144-
test_reverse_mode_ad((B, x)->logpdf(MvNormal(_to_cov(B)), x), randn(rng), B, x)
145-
test_reverse_mode_ad(B->logpdf(MvNormal(_to_cov(B)), x), randn(rng), B)
146-
147-
# zero mean, diagonal covariance
148-
test_reverse_mode_ad((b, x)->logpdf(MvNormal(exp.(b)), x), randn(rng), b, x)
149-
test_reverse_mode_ad(b->logpdf(MvNormal(exp.(b)), x), randn(rng), b)
150-
151-
# dense mean, dense covariance
152-
test_reverse_mode_ad((m, B, x)->logpdf(MvNormal(m, _to_cov(B)), x),
153-
randn(rng),
154-
randn(rng, N), randn(rng, N, N), randn(rng, N),
155-
)
156-
test_reverse_mode_ad((m, B)->logpdf(MvNormal(m, _to_cov(B)), x),
157-
randn(rng),
158-
randn(rng, N), randn(rng, N, N),
159-
)
160-
test_reverse_mode_ad((m, x)->logpdf(MvNormal(m, _to_cov(B)), x),
161-
randn(rng),
162-
randn(rng, N), randn(rng, N),
163-
)
164-
test_reverse_mode_ad((B, x)->logpdf(MvNormal(m, _to_cov(B)), x),
165-
randn(rng),
166-
randn(rng, N, N), randn(rng, N),
167-
)
168-
test_reverse_mode_ad(m->logpdf(MvNormal(m, _to_cov(B)), x), randn(rng), randn(rng, N))
169-
test_reverse_mode_ad(B->logpdf(MvNormal(m, _to_cov(B)), x), randn(rng), randn(rng, N, N))
170-
171-
# dense mean, diagonal covariance
172-
test_reverse_mode_ad((m, b, x)->logpdf(MvNormal(m, Diagonal(exp.(b))), x),
173-
randn(rng),
174-
randn(rng, N), randn(rng, N), randn(rng, N),
175-
)
176-
test_reverse_mode_ad((m, b)->logpdf(MvNormal(m, Diagonal(exp.(b))), x),
177-
randn(rng),
178-
randn(rng, N), randn(rng, N),
179-
)
180-
test_reverse_mode_ad((m, x)->logpdf(MvNormal(m, Diagonal(exp.(b))), x),
181-
randn(rng),
182-
randn(rng, N), randn(rng, N),
183-
)
184-
test_reverse_mode_ad((b, x)->logpdf(MvNormal(m, Diagonal(exp.(b))), x),
185-
randn(rng),
186-
randn(rng, N), randn(rng, N),
187-
)
188-
test_reverse_mode_ad(m->logpdf(MvNormal(m, Diagonal(exp.(b))), x),
189-
randn(rng),
190-
randn(rng, N),
191-
)
192-
test_reverse_mode_ad(b->logpdf(MvNormal(m, Diagonal(exp.(b))), x),
193-
randn(rng),
194-
randn(rng, N),
195-
)
196-
197-
# dense mean, diagonal variance
198-
test_reverse_mode_ad((m, b, x)->logpdf(MvNormal(m, exp.(b)), x),
199-
randn(rng),
200-
randn(rng, N), randn(rng, N), randn(rng, N),
201-
)
202-
test_reverse_mode_ad((m, b)->logpdf(MvNormal(m, exp.(b)), x),
203-
randn(rng),
204-
randn(rng, N), randn(rng, N),
205-
)
206-
test_reverse_mode_ad((m, x)->logpdf(MvNormal(m, exp.(b)), x),
207-
randn(rng),
208-
randn(rng, N), randn(rng, N),
209-
)
210-
test_reverse_mode_ad((b, x)->logpdf(MvNormal(m, exp.(b)), x),
211-
randn(rng),
212-
randn(rng, N), randn(rng, N),
213-
)
214-
test_reverse_mode_ad(m->logpdf(MvNormal(m, exp.(b)), x), randn(rng), randn(rng, N))
215-
test_reverse_mode_ad(b->logpdf(MvNormal(m, exp.(b)), x), randn(rng), randn(rng, N))
216-
217-
# dense mean, constant covariance
218-
b_s = randn(rng)
219-
test_reverse_mode_ad((m, b, x)->logpdf(MvNormal(m, exp(b)), x),
220-
randn(rng),
221-
randn(rng, N), randn(rng), randn(rng, N),
222-
)
223-
test_reverse_mode_ad((m, b)->logpdf(MvNormal(m, exp(b)), x),
224-
randn(rng),
225-
randn(rng, N), randn(rng),
226-
)
227-
test_reverse_mode_ad((m, x)->logpdf(MvNormal(m, exp(b_s)), x),
228-
randn(rng),
229-
randn(rng, N), randn(rng, N)
230-
)
231-
test_reverse_mode_ad((b, x)->logpdf(MvNormal(m, exp(b)), x),
232-
randn(rng),
233-
randn(rng), randn(rng, N),
234-
)
235-
test_reverse_mode_ad(m->logpdf(MvNormal(m, exp(b_s)), x), randn(rng), randn(rng, N))
236-
test_reverse_mode_ad(b->logpdf(MvNormal(m, exp(b)), x), randn(rng), randn(rng))
237-
238-
# dense mean, constant variance
239-
b_s = randn(rng)
240-
test_reverse_mode_ad((m, b, x)->logpdf(MvNormal(m, exp(b) * I), x),
241-
randn(rng),
242-
randn(rng, N), randn(rng), randn(rng, N),
243-
)
244-
test_reverse_mode_ad((m, b)->logpdf(MvNormal(m, exp(b) * I), x),
245-
randn(rng),
246-
randn(rng, N), randn(rng),
247-
)
248-
test_reverse_mode_ad((m, x)->logpdf(MvNormal(m, exp(b_s) * I), x),
249-
randn(rng),
250-
randn(rng, N), randn(rng, N),
251-
)
252-
test_reverse_mode_ad((b, x)->logpdf(MvNormal(m, exp(b) * I), x),
253-
randn(rng),
254-
randn(rng), randn(rng, N),
255-
)
256-
test_reverse_mode_ad(m->logpdf(MvNormal(m, exp(b_s) * I), x), randn(rng), randn(rng, N))
257-
test_reverse_mode_ad(b->logpdf(MvNormal(m, exp(b) * I), x), randn(rng), randn(rng))
258-
259-
# zero mean, constant variance
260-
test_reverse_mode_ad((b, x)->logpdf(MvNormal(N, exp(b)), x),
261-
randn(rng),
262-
randn(rng), randn(rng, N),
263-
)
264-
test_reverse_mode_ad(b->logpdf(MvNormal(N, exp(b)), x), randn(rng), randn(rng))
265-
end
26681
@testset "Simplex Tracker, Zygote and ReverseDiff (with and without caching) AD" begin
26782
@model dir() = begin
26883
theta ~ Dirichlet(1 ./ fill(4, 4))

0 commit comments

Comments
 (0)