Skip to content

Commit 1aee09e

Browse files
authored
Merge pull request #42 from TuringLang/dw/enzyme
Add Enzyme support
2 parents c2c4d8b + e7265af commit 1aee09e

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/AdvancedVI.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,34 @@ function __init__()
7676
return out
7777
end
7878
end
79+
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
80+
include("compat/enzyme.jl")
81+
export EnzymeAD
82+
83+
function AdvancedVI.grad!(
84+
vo,
85+
alg::VariationalInference{<:AdvancedVI.EnzymeAD},
86+
q,
87+
model,
88+
θ::AbstractVector{<:Real},
89+
out::DiffResults.MutableDiffResult,
90+
args...
91+
)
92+
f(θ) = if (q isa Distribution)
93+
- vo(alg, update(q, θ), model, args...)
94+
else
95+
- vo(alg, q(θ), model, args...)
96+
end
97+
# Use `Enzyme.ReverseWithPrimal` once it is released:
98+
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
99+
y = f(θ)
100+
DiffResults.value!(out, y)
101+
dy = DiffResults.gradient(out)
102+
fill!(dy, 0)
103+
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
104+
return out
105+
end
106+
end
79107
end
80108

81109
export

src/compat/enzyme.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
struct EnzymeAD <: ADBackend end
2+
ADBackend(::Val{:enzyme}) = EnzymeAD
3+
function setadbackend(::Val{:enzyme})
4+
ADBACKEND[] = :enzyme
5+
end

0 commit comments

Comments
 (0)