Skip to content

Commit 3010669

Browse files
committed
update tests to new interface
1 parent 5ff6041 commit 3010669

File tree

5 files changed

+27
-24
lines changed

5 files changed

+27
-24
lines changed

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
4-
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
4+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
55
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
6-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
76
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
87
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
99
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

test/ad.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1-
@testset "AD correctness" begin
2-
f(x) = sum(abs2, x)
1+
@testset "DI.AD with context wrapper" begin
2+
f(x, y, z) = sum(abs2, x .+ y .+ z)
33

44
@testset "$T" for T in [Float32, Float64]
55
x = randn(T, 10)
6+
y = randn(T, 10)
7+
z = randn(T, 10)
68
chunksize = size(x, 1)
79

810
@testset "$at" for at in [
911
ADTypes.AutoZygote(),
1012
ADTypes.AutoForwardDiff(; chunksize=chunksize),
1113
ADTypes.AutoForwardDiff(),
1214
ADTypes.AutoReverseDiff(false),
13-
ADTypes.AutoEnzyme(),
15+
ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()),
1416
]
15-
out = DiffResults.GradientResult(x)
16-
NormalizingFlows.value_and_gradient!(at, f, x, out)
17-
@test DiffResults.value(out) f(x)
18-
@test DiffResults.gradient(out) 2x
17+
value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z)
18+
@test DiffResults.value(out) f(x, y, z)
19+
@test DiffResults.gradient(out) 2 * (x .+ y .+ z)
1920
end
2021
end
2122
end
@@ -25,7 +26,7 @@ end
2526
ADTypes.AutoZygote(),
2627
ADTypes.AutoForwardDiff(),
2728
ADTypes.AutoReverseDiff(false),
28-
# ADTypes.AutoEnzyme(), # not working now
29+
ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()),
2930
]
3031
@testset "$T" for T in [Float32, Float64]
3132
μ = 10 * ones(T, 2)
@@ -38,15 +39,15 @@ end
3839

3940
sample_per_iter = 10
4041
θ, re = Optimisers.destructure(flow)
41-
out = DiffResults.GradientResult(θ)
4242

4343
# check grad computation for elbo
44-
NormalizingFlows.grad!(
45-
Random.default_rng(), at, elbo, θ, re, out, logp, sample_per_iter
44+
loss(θ, args...) = -NormalizingFlows.elbo(re(θ), args...)
45+
value, grad = NormalizingFlows._value_and_gradient(
46+
loss, at, θ, logp, randn(T, 2, sample_per_iter)
4647
)
4748

48-
@test DiffResults.value(out) != nothing
49-
@test all(DiffResults.gradient(out) .!= nothing)
49+
@test !isnothing(value)
50+
@test all(grad .!= nothing)
5051
end
5152
end
52-
end
53+
end

test/interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
ADTypes.AutoZygote(),
55
ADTypes.AutoForwardDiff(; chunksize=chunksize),
66
ADTypes.AutoForwardDiff(),
7-
ADTypes.AutoReverseDiff(false),
8-
# ADTypes.AutoEnzyme(), # doesn't work for Enzyme
7+
ADTypes.AutoReverseDiff(),
8+
ADTypes.AutoMooncake(; config = ADTypes.Mooncake.Config()),
99
]
1010
@testset "$T" for T in [Float32, Float64]
1111
μ = 10 * ones(T, 2)
@@ -44,4 +44,4 @@
4444
@test el_trained > -1
4545
end
4646
end
47-
end
47+
end

test/objectives.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
flow = Bijectors.transformed(q₀, Bijectors.Shift(μ) Bijectors.Scale(sqrt.(Σ)))
1010

1111
x = randn(T, 2)
12+
rng = Random.default_rng()
1213

1314
@testset "elbo" begin
14-
el = elbo(Random.default_rng(), flow, logp, 10)
15+
el = elbo(rng, flow, logp, 10)
1516

1617
@test abs(el) 1e-5
1718
@test logpdf(flow, x) + el logp(x)
@@ -20,8 +21,8 @@
2021
@testset "likelihood" begin
2122
sample_trained = rand(flow, 1000)
2223
sample_untrained = rand(q₀, 1000)
23-
llh_trained = NormalizingFlows.loglikelihood(flow, sample_trained)
24-
llh_untrained = NormalizingFlows.loglikelihood(flow, sample_untrained)
24+
llh_trained = NormalizingFlows.loglikelihood(rng, flow, sample_trained)
25+
llh_untrained = NormalizingFlows.loglikelihood(rng, flow, sample_untrained)
2526

2627
@test llh_trained > llh_untrained
2728
end

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ using Distributions
33
using Bijectors, Optimisers
44
using LinearAlgebra
55
using Random
6-
using ADTypes, DiffResults
6+
using ADTypes
7+
import DifferentiationInterface as DI
78
using ForwardDiff, Zygote, Enzyme, ReverseDiff
89
using Test
910

1011
include("ad.jl")
1112
include("objectives.jl")
12-
include("interface.jl")
13+
include("interface.jl")

0 commit comments

Comments
 (0)