Skip to content

Commit b3487b5

Browse files
committed
fixing extension loading error
1 parent 9a8ed04 commit b3487b5

File tree

9 files changed

+117
-44
lines changed

9 files changed

+117
-44
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1616
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1717

1818
[weakdeps]
19-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
19+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2020

2121
[extensions]
22-
NormalizingFlowsEnzymeExt = "Enzyme"
22+
NormalizingFlowsEnzymeCoreExt = "EnzymeCore"
2323

2424
[compat]
2525
ADTypes = "1"
@@ -32,3 +32,6 @@ ProgressMeter = "1.0.0"
3232
Requires = "1"
3333
StatsBase = "0.33, 0.34"
3434
julia = "1.10"
35+
36+
[extras]
37+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
module NormalizingFlowsEnzymeCoreExt

using EnzymeCore
using NormalizingFlows
using NormalizingFlows: ADTypes, DifferentiationInterface

# Wrapping the loss in `EnzymeCore.Const` works around Enzyme's readonly error;
# see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012
#
# Extra positional arguments are passed to DifferentiationInterface as
# `Constant` contexts so Enzyme does not differentiate with respect to them.

"""
    _prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...)

Build a DifferentiationInterface gradient preparation object for `loss` at `θ`
under the Enzyme backend, treating `loss` and any trailing `args` as constants.
"""
function NormalizingFlows._prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...)
    frozen_loss = EnzymeCore.Const(loss)
    if isempty(args)
        return DifferentiationInterface.prepare_gradient(frozen_loss, adbackend, θ)
    end
    contexts = map(DifferentiationInterface.Constant, args)
    return DifferentiationInterface.prepare_gradient(frozen_loss, adbackend, θ, contexts...)
end

"""
    _value_and_gradient(loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...)

Evaluate `loss` and its gradient at `θ` under the Enzyme backend, reusing the
preparation `prep`; `loss` and any trailing `args` are held constant.
"""
function NormalizingFlows._value_and_gradient(
    loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...
)
    frozen_loss = EnzymeCore.Const(loss)
    if isempty(args)
        return DifferentiationInterface.value_and_gradient(frozen_loss, prep, adbackend, θ)
    end
    contexts = map(DifferentiationInterface.Constant, args)
    return DifferentiationInterface.value_and_gradient(
        frozen_loss, prep, adbackend, θ, contexts...
    )
end

end

ext/NormalizingFlowsEnzymeExt.jl

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/NormalizingFlows.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,17 @@ include("optimize.jl")
7979
include("objectives.jl")
8080

8181

82-
# optional dependencies
83-
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
82+
if !isdefined(Base, :get_extension)
8483
using Requires
8584
end
8685

87-
# Question: should Exts be loaded here or in train.jl?
88-
function __init__()
89-
@static if !isdefined(Base, :get_extension)
90-
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
91-
"../ext/NormalizingFlowsEnzymeExt.jl"
86+
87+
@static if !isdefined(Base, :get_extension)
88+
function __init__()
89+
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
90+
joinpath(@__DIR__, "../ext/NormalizingFlowsEnzymeCoreExt.jl")
9291
)
9392
end
9493
end
9594

96-
9795
end

src/optimize.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ function optimize(
9393
converged = false
9494
i = 1
9595
while (i ≤ max_iters) && !converged
96-
# ls, g = DifferentiationInterface.value_and_gradient!(loss, grad, prep, adbackend, θ)
9796
ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
9897

9998
# Save stats

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
44
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
55
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
66
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
7+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
78
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
89
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/ad.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
1616
ADTypes.AutoMooncake(; config=Mooncake.Config()),
1717
]
18-
# at = ADTypes.AutoMooncake(; config=Mooncake.Config())
1918
prep = NormalizingFlows._prepare_gradient(f, at, x, y, z)
2019
value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z)
2120
@test value f(x, y, z)
@@ -39,11 +38,13 @@ end
3938
logp(z) = logpdf(target, z)
4039

4140
# necessary for Zygote/mooncake to differentiate through the flow
42-
# prevent opt q0
41+
# prevent updating params of q0
4342
@leaf MvNormal
4443
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
45-
flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ)))
46-
44+
flow = Bijectors.transformed(
45+
q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2))
46+
)
47+
4748
θ, re = Optimisers.destructure(flow)
4849

4950
# check grad computation for elbo

test/interface.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "learining 2d Gaussian" begin
1+
@testset "testing mean-field Gaussian VI" begin
22
chunksize = 4
33
@testset "$adtype" for adtype in [
44
ADTypes.AutoZygote(),
@@ -48,3 +48,61 @@
4848
end
4949
end
5050
end
51+
52+
# function create_planar_flow(n_layers::Int, q₀, T)
53+
# d = length(q₀)
54+
# if T == Float32
55+
# Ls = reduce(∘, [f32(PlanarLayer(d)) for _ in 1:n_layers])
56+
# else
57+
# Ls = reduce(∘, [PlanarLayer(d) for _ in 1:n_layers])
58+
# end
59+
# return Bijectors.transformed(q₀, Ls)
60+
# end
61+
62+
# @testset "testing planar flow" begin
63+
# chunksize = 4
64+
# @testset "$adtype" for adtype in [
65+
# ADTypes.AutoZygote(),
66+
# ADTypes.AutoForwardDiff(; chunksize=chunksize),
67+
# ADTypes.AutoForwardDiff(),
68+
# ADTypes.AutoReverseDiff(),
69+
# ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
70+
# # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64
71+
# ]
72+
# @testset "$T" for T in [Float32, Float64]
73+
# μ = 10 * ones(T, 2)
74+
# Σ = Diagonal(4 * ones(T, 2))
75+
76+
# target = MvNormal(μ, Σ)
77+
# logp(z) = logpdf(target, z)
78+
79+
# @leaf MvNormal
80+
# q₀ = MvNormal(zeros(T, 2), ones(T, 2))
81+
# nlayers = 10
82+
# flow = create_planar_flow(nlayers, q₀, T)
83+
84+
# sample_per_iter = 10
85+
# cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
86+
# checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
87+
# flow_trained, stats, _, _ = train_flow(
88+
# elbo,
89+
# flow,
90+
# logp,
91+
# sample_per_iter;
92+
# max_iters=10_000,
93+
# optimiser=Optimisers.Adam(one(T)/100),
94+
# ADbackend=adtype,
95+
# show_progress=false,
96+
# callback=cb,
97+
# hasconverged=checkconv,
98+
# )
99+
# θ, re = Optimisers.destructure(flow_trained)
100+
101+
# el_untrained = elbo(flow, logp, 1000)
102+
# el_trained = elbo(flow_trained, logp, 1000)
103+
104+
# @test el_trained > el_untrained
105+
# @test el_trained > -1
106+
# end
107+
# end
108+
# end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ using Random
66
using ADTypes
77
using Functors
88
using ForwardDiff, Zygote, ReverseDiff, Enzyme, Mooncake
9+
using Flux: f32
910
import DifferentiationInterface as DI
11+
1012
using Test
1113

1214
include("ad.jl")

0 commit comments

Comments
 (0)