diff --git a/ext/NormalizingFlowsEnzymeExt.jl b/ext/NormalizingFlowsEnzymeExt.jl
index 1b59cad8..a00c864a 100644
--- a/ext/NormalizingFlowsEnzymeExt.jl
+++ b/ext/NormalizingFlowsEnzymeExt.jl
@@ -10,16 +10,14 @@ else
     using ..NormalizingFlows: ADTypes, DiffResults
 end
 
-# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
 function NormalizingFlows.value_and_gradient!(
     ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
 ) where {T<:Real}
-    y = f(θ)
-    DiffResults.value!(out, y)
     ∇θ = DiffResults.gradient(out)
     fill!(∇θ, zero(T))
-    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
+    _, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
+    DiffResults.value!(out, y)
     return out
 end
 
-end
\ No newline at end of file
+end
diff --git a/test/ad.jl b/test/ad.jl
index a394d806..aab6841b 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -25,7 +25,7 @@ end
         ADTypes.AutoZygote(),
         ADTypes.AutoForwardDiff(),
         ADTypes.AutoReverseDiff(false),
-        # ADTypes.AutoEnzyme(), # not working now
+        ADTypes.AutoEnzyme(),
     ]
         @testset "$T" for T in [Float32, Float64]
             μ = 10 * ones(T, 2)
@@ -41,12 +41,20 @@ end
             out = DiffResults.GradientResult(θ)
 
             # check grad computation for elbo
+            # Enzyme needs a workaround
+            if at isa ADTypes.AutoEnzyme
+                activity = Enzyme.API.runtimeActivity()
+                Enzyme.API.runtimeActivity!(true)
+            end
             NormalizingFlows.grad!(
                 Random.default_rng(), at, elbo, θ, re, out, logp, sample_per_iter
             )
+            if at isa ADTypes.AutoEnzyme
+                Enzyme.API.runtimeActivity!(activity)
+            end
 
             @test DiffResults.value(out) != nothing
             @test all(DiffResults.gradient(out) .!= nothing)
         end
     end
-end
\ No newline at end of file
+end