@@ -12,7 +12,6 @@ const FDM = FiniteDifferences
12
12
dir = splitdir (splitdir (pathof (Turing))[1 ])[1 ]
13
13
include (dir* " /test/test_utils/AllUtils.jl" )
14
14
15
- _to_cov (B) = B * B' + Matrix (I, size (B)... )
16
15
@testset " ad.jl" begin
17
16
@turing_testset " adr" begin
18
17
ad_test_f = gdemo_default
@@ -49,20 +48,6 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
49
48
∇E2 = gradient_logp (ZygoteAD (), x, vi, ad_test_f)[2 ]
50
49
@test sort (∇E2) ≈ grad_FWAD atol= 1e-9
51
50
end
52
- @turing_testset " passing duals to distributions" begin
53
- float1 = 1.1
54
- float2 = 2.3
55
- d1 = Dual (1.1 )
56
- d2 = Dual (2.3 )
57
-
58
- @test logpdf (Normal (0 , 1 ), d1). value ≈ logpdf (Normal (0 , 1 ), float1) atol= 0.001
59
- @test logpdf (Gamma (2 , 3 ), d2). value ≈ logpdf (Gamma (2 , 3 ), float2) atol= 0.001
60
- @test logpdf (Beta (2 , 3 ), (d2 - d1) / 2 ). value ≈ logpdf (Beta (2 , 3 ), (float2 - float1) / 2 ) atol= 0.001
61
-
62
- @test pdf (Normal (0 , 1 ), d1). value ≈ pdf (Normal (0 , 1 ), float1) atol= 0.001
63
- @test pdf (Gamma (2 , 3 ), d2). value ≈ pdf (Gamma (2 , 3 ), float2) atol= 0.001
64
- @test pdf (Beta (2 , 3 ), (d2 - d1) / 2 ). value ≈ pdf (Beta (2 , 3 ), (float2 - float1) / 2 ) atol= 0.001
65
- end
66
51
@turing_testset " general AD tests" begin
67
52
# Tests gdemo gradient.
68
53
function logp1 (x:: Vector )
@@ -93,176 +78,6 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
93
78
94
79
test_model_ad (wishart_ad (), logp3, [:v ])
95
80
end
96
- @turing_testset " Tracker, Zygote and ReverseDiff + logdet" begin
97
- rng, N = MersenneTwister (123456 ), 7
98
- ȳ, B = randn (rng), randn (rng, N, N)
99
- test_reverse_mode_ad (B-> logdet (cholesky (_to_cov (B))), ȳ, B; rtol= 1e-8 , atol= 1e-6 )
100
- end
101
- @turing_testset " Tracker & Zygote + fill" begin
102
- rng = MersenneTwister (123456 )
103
- test_reverse_mode_ad (x-> fill (x, 7 ), randn (rng, 7 ), randn (rng))
104
- test_reverse_mode_ad (x-> fill (x, 7 , 11 ), randn (rng, 7 , 11 ), randn (rng))
105
- test_reverse_mode_ad (x-> fill (x, 7 , 11 , 13 ), rand (rng, 7 , 11 , 13 ), randn (rng))
106
- end
107
- @turing_testset " Tracker, Zygote and ReverseDiff + MvNormal" begin
108
- rng, N = MersenneTwister (123456 ), 11
109
- B = randn (rng, N, N)
110
- m, A = randn (rng, N), B' * B + I
111
-
112
- # Generate from the TuringDenseMvNormal
113
- d, back = Tracker. forward (TuringDenseMvNormal, m, A)
114
- x = Tracker. data (rand (d))
115
-
116
- # Check that the logpdf agrees with MvNormal.
117
- d_ref = MvNormal (m, PDMat (A))
118
- @test logpdf (d, x) ≈ logpdf (d_ref, x)
119
-
120
- test_reverse_mode_ad ((m, B, x)-> logpdf (MvNormal (m, _to_cov (B)), x), randn (rng), m, B, x)
121
- end
122
- @turing_testset " Tracker, Zygote and ReverseDiff + Diagonal Normal" begin
123
- rng, N = MersenneTwister (123456 ), 11
124
- m, σ = randn (rng, N), exp .(0.1 .* randn (rng, N)) .+ 1
125
-
126
- d = TuringDiagMvNormal (m, σ)
127
- x = rand (d)
128
-
129
- # Check that the logpdf agrees with MvNormal.
130
- d_ref = MvNormal (m, σ)
131
- @test logpdf (d, x) ≈ logpdf (d_ref, x)
132
-
133
- test_reverse_mode_ad ((m, σ, x)-> logpdf (MvNormal (m, σ), x), randn (rng), m, σ, x)
134
- end
135
- @turing_testset " Tracker, Zygote and ReverseDiff + MvNormal Interface" begin
136
- # Note that we only test methods where the `MvNormal` ctor actually constructs
137
- # a TuringDenseMvNormal.
138
-
139
- rng, N = MersenneTwister (123456 ), 7
140
- m, b, B, x = randn (rng, N), randn (rng, N), randn (rng, N, N), randn (rng, N)
141
- ȳ = randn (rng)
142
-
143
- # zero mean, dense covariance
144
- test_reverse_mode_ad ((B, x)-> logpdf (MvNormal (_to_cov (B)), x), randn (rng), B, x)
145
- test_reverse_mode_ad (B-> logpdf (MvNormal (_to_cov (B)), x), randn (rng), B)
146
-
147
- # zero mean, diagonal covariance
148
- test_reverse_mode_ad ((b, x)-> logpdf (MvNormal (exp .(b)), x), randn (rng), b, x)
149
- test_reverse_mode_ad (b-> logpdf (MvNormal (exp .(b)), x), randn (rng), b)
150
-
151
- # dense mean, dense covariance
152
- test_reverse_mode_ad ((m, B, x)-> logpdf (MvNormal (m, _to_cov (B)), x),
153
- randn (rng),
154
- randn (rng, N), randn (rng, N, N), randn (rng, N),
155
- )
156
- test_reverse_mode_ad ((m, B)-> logpdf (MvNormal (m, _to_cov (B)), x),
157
- randn (rng),
158
- randn (rng, N), randn (rng, N, N),
159
- )
160
- test_reverse_mode_ad ((m, x)-> logpdf (MvNormal (m, _to_cov (B)), x),
161
- randn (rng),
162
- randn (rng, N), randn (rng, N),
163
- )
164
- test_reverse_mode_ad ((B, x)-> logpdf (MvNormal (m, _to_cov (B)), x),
165
- randn (rng),
166
- randn (rng, N, N), randn (rng, N),
167
- )
168
- test_reverse_mode_ad (m-> logpdf (MvNormal (m, _to_cov (B)), x), randn (rng), randn (rng, N))
169
- test_reverse_mode_ad (B-> logpdf (MvNormal (m, _to_cov (B)), x), randn (rng), randn (rng, N, N))
170
-
171
- # dense mean, diagonal covariance
172
- test_reverse_mode_ad ((m, b, x)-> logpdf (MvNormal (m, Diagonal (exp .(b))), x),
173
- randn (rng),
174
- randn (rng, N), randn (rng, N), randn (rng, N),
175
- )
176
- test_reverse_mode_ad ((m, b)-> logpdf (MvNormal (m, Diagonal (exp .(b))), x),
177
- randn (rng),
178
- randn (rng, N), randn (rng, N),
179
- )
180
- test_reverse_mode_ad ((m, x)-> logpdf (MvNormal (m, Diagonal (exp .(b))), x),
181
- randn (rng),
182
- randn (rng, N), randn (rng, N),
183
- )
184
- test_reverse_mode_ad ((b, x)-> logpdf (MvNormal (m, Diagonal (exp .(b))), x),
185
- randn (rng),
186
- randn (rng, N), randn (rng, N),
187
- )
188
- test_reverse_mode_ad (m-> logpdf (MvNormal (m, Diagonal (exp .(b))), x),
189
- randn (rng),
190
- randn (rng, N),
191
- )
192
- test_reverse_mode_ad (b-> logpdf (MvNormal (m, Diagonal (exp .(b))), x),
193
- randn (rng),
194
- randn (rng, N),
195
- )
196
-
197
- # dense mean, diagonal variance
198
- test_reverse_mode_ad ((m, b, x)-> logpdf (MvNormal (m, exp .(b)), x),
199
- randn (rng),
200
- randn (rng, N), randn (rng, N), randn (rng, N),
201
- )
202
- test_reverse_mode_ad ((m, b)-> logpdf (MvNormal (m, exp .(b)), x),
203
- randn (rng),
204
- randn (rng, N), randn (rng, N),
205
- )
206
- test_reverse_mode_ad ((m, x)-> logpdf (MvNormal (m, exp .(b)), x),
207
- randn (rng),
208
- randn (rng, N), randn (rng, N),
209
- )
210
- test_reverse_mode_ad ((b, x)-> logpdf (MvNormal (m, exp .(b)), x),
211
- randn (rng),
212
- randn (rng, N), randn (rng, N),
213
- )
214
- test_reverse_mode_ad (m-> logpdf (MvNormal (m, exp .(b)), x), randn (rng), randn (rng, N))
215
- test_reverse_mode_ad (b-> logpdf (MvNormal (m, exp .(b)), x), randn (rng), randn (rng, N))
216
-
217
- # dense mean, constant covariance
218
- b_s = randn (rng)
219
- test_reverse_mode_ad ((m, b, x)-> logpdf (MvNormal (m, exp (b)), x),
220
- randn (rng),
221
- randn (rng, N), randn (rng), randn (rng, N),
222
- )
223
- test_reverse_mode_ad ((m, b)-> logpdf (MvNormal (m, exp (b)), x),
224
- randn (rng),
225
- randn (rng, N), randn (rng),
226
- )
227
- test_reverse_mode_ad ((m, x)-> logpdf (MvNormal (m, exp (b_s)), x),
228
- randn (rng),
229
- randn (rng, N), randn (rng, N)
230
- )
231
- test_reverse_mode_ad ((b, x)-> logpdf (MvNormal (m, exp (b)), x),
232
- randn (rng),
233
- randn (rng), randn (rng, N),
234
- )
235
- test_reverse_mode_ad (m-> logpdf (MvNormal (m, exp (b_s)), x), randn (rng), randn (rng, N))
236
- test_reverse_mode_ad (b-> logpdf (MvNormal (m, exp (b)), x), randn (rng), randn (rng))
237
-
238
- # dense mean, constant variance
239
- b_s = randn (rng)
240
- test_reverse_mode_ad ((m, b, x)-> logpdf (MvNormal (m, exp (b) * I), x),
241
- randn (rng),
242
- randn (rng, N), randn (rng), randn (rng, N),
243
- )
244
- test_reverse_mode_ad ((m, b)-> logpdf (MvNormal (m, exp (b) * I), x),
245
- randn (rng),
246
- randn (rng, N), randn (rng),
247
- )
248
- test_reverse_mode_ad ((m, x)-> logpdf (MvNormal (m, exp (b_s) * I), x),
249
- randn (rng),
250
- randn (rng, N), randn (rng, N),
251
- )
252
- test_reverse_mode_ad ((b, x)-> logpdf (MvNormal (m, exp (b) * I), x),
253
- randn (rng),
254
- randn (rng), randn (rng, N),
255
- )
256
- test_reverse_mode_ad (m-> logpdf (MvNormal (m, exp (b_s) * I), x), randn (rng), randn (rng, N))
257
- test_reverse_mode_ad (b-> logpdf (MvNormal (m, exp (b) * I), x), randn (rng), randn (rng))
258
-
259
- # zero mean, constant variance
260
- test_reverse_mode_ad ((b, x)-> logpdf (MvNormal (N, exp (b)), x),
261
- randn (rng),
262
- randn (rng), randn (rng, N),
263
- )
264
- test_reverse_mode_ad (b-> logpdf (MvNormal (N, exp (b)), x), randn (rng), randn (rng))
265
- end
266
81
@testset " Simplex Tracker, Zygote and ReverseDiff (with and without caching) AD" begin
267
82
@model dir () = begin
268
83
theta ~ Dirichlet (1 ./ fill (4 , 4 ))
0 commit comments