Skip to content

Commit dcee3c0

Browse files
committed
rm unnecessary EnzymeCoreExt
1 parent 91202ff commit dcee3c0

File tree

6 files changed

+26
-53
lines changed

6 files changed

+26
-53
lines changed

Project.toml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1616

17-
[weakdeps]
18-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
19-
20-
[extensions]
21-
NormalizingFlowsEnzymeCoreExt = ["EnzymeCore", "ADTypes", "DifferentiationInterface"]
22-
2317
[compat]
2418
ADTypes = "1"
2519
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
@@ -30,6 +24,3 @@ Optimisers = "0.2.16, 0.3, 0.4"
3024
ProgressMeter = "1.0.0"
3125
StatsBase = "0.33, 0.34"
3226
julia = "1.10"
33-
34-
[extras]
35-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

ext/NormalizingFlowsEnzymeCoreExt.jl

Lines changed: 0 additions & 29 deletions
This file was deleted.

src/NormalizingFlows.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Optimisers
55
using LinearAlgebra, Random, Distributions, StatsBase
66
using ProgressMeter
77
using ADTypes
8-
using DifferentiationInterface
8+
import DifferentiationInterface as DI
99

1010
using DocStringExtensions
1111

src/optimize.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@ function pm_next!(pm, stats::NamedTuple)
55
return ProgressMeter.next!(pm; showvalues=map(tuple, keys(stats), values(stats)))
66
end
77

8-
_wrap_in_DI_context(args) = map(DifferentiationInterface.Constant, args)
9-
108
function _prepare_gradient(loss, adbackend, θ, args...)
11-
return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, map(DifferentiationInterface.Constant, args)...)
9+
return DI.prepare_gradient(loss, adbackend, θ, map(DI.Constant, args)...)
1210
end
1311

1412
function _value_and_gradient(loss, prep, adbackend, θ, args...)
15-
return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...)
13+
return DI.value_and_gradient(loss, prep, adbackend, θ, map(DI.Constant, args)...)
1614
end
1715

1816

@@ -34,7 +32,7 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal
3432
- `loss`: a general loss function θ -> loss(θ, args...) returning a scalar loss value that will be minimised
3533
- `θ₀::AbstractVector{T}`: initial parameters for the loss function (in the context of normalizing flows, it will be the flattened flow parameters)
3634
- `re`: reconstruction function that maps the flattened parameters to the normalizing flow
37-
- `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant)
35+
- `args...`: additional arguments for `loss` (will be set as DI.Constant)
3836
3937
# Keyword Arguments
4038
- `max_iters::Int=10000`: maximum number of iterations

test/ad.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
ADTypes.AutoForwardDiff(; chunksize=chunksize),
1313
ADTypes.AutoForwardDiff(),
1414
ADTypes.AutoReverseDiff(; compile=false),
15-
ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
15+
ADTypes.AutoEnzyme(;
16+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
17+
function_annotation=Enzyme.Const,
18+
),
1619
ADTypes.AutoMooncake(; config=Mooncake.Config()),
1720
]
1821
prep = NormalizingFlows._prepare_gradient(f, at, x, y, z)
@@ -27,33 +30,40 @@ end
2730
@testset "$at" for at in [
2831
ADTypes.AutoZygote(),
2932
ADTypes.AutoForwardDiff(),
30-
ADTypes.AutoReverseDiff(; compile = false),
31-
ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
33+
ADTypes.AutoReverseDiff(; compile=false),
34+
ADTypes.AutoEnzyme(;
35+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
36+
function_annotation=Enzyme.Const,
37+
),
3238
ADTypes.AutoMooncake(; config=Mooncake.Config()),
3339
]
3440
@testset "$T" for T in [Float32, Float64]
3541
μ = 10 * ones(T, 2)
3642
Σ = Diagonal(4 * ones(T, 2))
3743
target = MvNormal(μ, Σ)
3844
logp(z) = logpdf(target, z)
39-
45+
4046
# necessary for Zygote/mooncake to differentiate through the flow
4147
# prevent updating params of q0
42-
@leaf MvNormal
48+
@leaf MvNormal
4349
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
4450
flow = Bijectors.transformed(
4551
q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2))
4652
)
47-
53+
4854
θ, re = Optimisers.destructure(flow)
4955

5056
# check grad computation for elbo
51-
loss(θ, rng, logp, sample_per_iter) = -NormalizingFlows.elbo(rng, re(θ), logp, sample_per_iter)
57+
function loss(θ, rng, logp, sample_per_iter)
58+
return -NormalizingFlows.elbo(rng, re(θ), logp, sample_per_iter)
59+
end
5260

5361
rng = Random.default_rng()
5462
sample_per_iter = 10
5563

56-
prep = NormalizingFlows._prepare_gradient(loss, at, θ, rng, logp, sample_per_iter)
64+
prep = NormalizingFlows._prepare_gradient(
65+
loss, at, θ, rng, logp, sample_per_iter
66+
)
5767
value, grad = NormalizingFlows._value_and_gradient(
5868
loss, prep, at, θ, rng, logp, sample_per_iter
5969
)

test/interface.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
ADTypes.AutoForwardDiff(; chunksize=chunksize),
66
ADTypes.AutoForwardDiff(),
77
ADTypes.AutoReverseDiff(),
8-
ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
8+
ADTypes.AutoEnzyme(;
9+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
10+
function_annotation=Enzyme.Const,
11+
),
912
ADTypes.AutoMooncake(; config = Mooncake.Config()),
1013
]
1114
@testset "$T" for T in [Float32, Float64]

0 commit comments

Comments (0)