Skip to content

Commit b0390f7

Browse files
committed
fix _value_and_grad wrapper bug
1 parent 1970b09 commit b0390f7

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

src/optimize.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@ function pm_next!(pm, stats::NamedTuple)
55
return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
66
end
77

8-
_wrap_in_DI_context(args...) = DifferentiationInterface.Constant.([args...])
8+
_wrap_in_DI_context(args) = DifferentiationInterface.Constant.([args...])
99

1010
function _prepare_gradient(loss, adbackend, θ, args...)
11-
if isempty(args...)
11+
if isempty(args)
1212
return DifferentiationInterface.prepare_gradient(loss, adbackend, θ)
1313
end
1414
return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, _wrap_in_DI_context(args)...)
1515
end
1616

1717
function _value_and_gradient(loss, prep, adbackend, θ, args...)
18-
if isempty(args...)
18+
if isempty(args)
1919
return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ)
2020
end
2121
return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, _wrap_in_DI_context(args)...)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
66
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
9+
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
910
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

test/ad.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testset "DI.AD with context wrapper" begin
22
f(x, y, z) = sum(abs2, x .+ y .+ z)
3+
T = Float32
34

45
@testset "$T" for T in [Float32, Float64]
56
x = randn(T, 10)
@@ -11,9 +12,10 @@
1112
ADTypes.AutoZygote(),
1213
ADTypes.AutoForwardDiff(; chunksize=chunksize),
1314
ADTypes.AutoForwardDiff(),
14-
ADTypes.AutoReverseDiff(false),
15+
ADTypes.AutoReverseDiff(; compile=false),
1516
ADTypes.AutoMooncake(; config=Mooncake.Config()),
1617
]
18+
at = ADTypes.AutoMooncake(; config=Mooncake.Config())
1719
prep = NormalizingFlows._prepare_gradient(f, at, x, y, z)
1820
value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z)
1921
@test DiffResults.value(out) ≈ f(x, y, z)
@@ -26,15 +28,18 @@ end
2628
@testset "$at" for at in [
2729
ADTypes.AutoZygote(),
2830
ADTypes.AutoForwardDiff(),
29-
ADTypes.AutoReverseDiff(false),
31+
ADTypes.AutoReverseDiff(; compile=false),
3032
ADTypes.AutoMooncake(; config=Mooncake.Config()),
3133
]
3234
@testset "$T" for T in [Float32, Float64]
3335
μ = 10 * ones(T, 2)
3436
Σ = Diagonal(4 * ones(T, 2))
3537
target = MvNormal(μ, Σ)
3638
logp(z) = logpdf(target, z)
37-
39+
40+
# necessary for Zygote/mooncake to differentiate through the flow
41+
# prevent opt q0
42+
@leaf MvNormal
3843
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
3944
flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ)))
4045

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Bijectors, Optimisers
44
using LinearAlgebra
55
using Random
66
using ADTypes
7-
import DifferentiationInterface as DI
7+
# import DifferentiationInterface as DI
88
using ForwardDiff, Zygote, ReverseDiff, Mooncake
99
using Test
1010

0 commit comments

Comments
 (0)