# AD-backend correctness test: `NormalizingFlows._value_and_gradient` must
# return a `(value, grad)` pair for a multi-argument objective under each
# supported ADTypes backend, with the gradient taken w.r.t. the first argument.
@testset "DI.AD with context wrapper" begin
    # Smooth objective of three array arguments. Its gradient w.r.t. `x` is
    # 2 .* (x .+ y .+ z), so correctness can be checked in closed form below.
    f(x, y, z) = sum(abs2, x .+ y .+ z)

    @testset "$T" for T in [Float32, Float64]
        x = randn(T, 10)
        y = randn(T, 10)
        z = randn(T, 10)
        # Use the full input length as the ForwardDiff chunk size to exercise
        # the explicit-chunksize code path alongside the default one.
        chunksize = size(x, 1)

        @testset "$at" for at in [
            ADTypes.AutoZygote(),
            ADTypes.AutoForwardDiff(; chunksize=chunksize),
            ADTypes.AutoForwardDiff(),
            ADTypes.AutoReverseDiff(false),
            # NOTE(review): confirm `Mooncake.Config` is reachable through the
            # `ADTypes` module as written here; it is conventionally spelled
            # `Mooncake.Config()` with Mooncake imported directly — verify
            # against the test suite's `using` block.
            ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()),
        ]
            value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z)
            # FIX: the previous assertions still read `DiffResults.value(out)`
            # and `DiffResults.gradient(out)`, but `out` (the old
            # `DiffResults.GradientResult` buffer) was removed when the call
            # switched to `_value_and_gradient` — every iteration would have
            # thrown `UndefVarError: out`. Assert on the returned pair instead.
            @test value ≈ f(x, y, z)
            @test grad ≈ 2 * (x .+ y .+ z)
        end
    end
end
25
26
ADTypes. AutoZygote (),
26
27
ADTypes. AutoForwardDiff (),
27
28
ADTypes. AutoReverseDiff (false ),
28
- # ADTypes.AutoEnzyme(), # not working now
29
+ ADTypes . AutoMooncake (; config = ADTypes. Mooncake . Config ()),
29
30
]
30
31
@testset " $T " for T in [Float32, Float64]
31
32
μ = 10 * ones (T, 2 )
38
39
39
40
sample_per_iter = 10
40
41
θ, re = Optimisers. destructure (flow)
41
- out = DiffResults. GradientResult (θ)
42
42
43
43
# check grad computation for elbo
44
- NormalizingFlows. grad! (
45
- Random. default_rng (), at, elbo, θ, re, out, logp, sample_per_iter
44
+ loss (θ, args... ) = - NormalizingFlows. elbo (re (θ), args... )
45
+ value, grad = NormalizingFlows. _value_and_gradient (
46
+ loss, at, θ, logp, randn (T, 2 , sample_per_iter)
46
47
)
47
48
48
- @test DiffResults . value (out) != nothing
49
- @test all (DiffResults . gradient (out) .!= nothing )
49
+ @test ! isnothing (value)
50
+ @test all (grad .!= nothing )
50
51
end
51
52
end
52
- end
53
+ end
0 commit comments