Skip to content

Commit 1b36c6e

Browse files
authored
Fix Enzyme extension (#79)
* Fix Enzyme extension * Enable Enzyme tests * Fix format * Fix format * Do not test on Julia nightly
1 parent 94cdd07 commit 1b36c6e

File tree

6 files changed

+19
-45
lines changed

6 files changed

+19
-45
lines changed

.github/workflows/JuliaNightly.yml

Lines changed: 0 additions & 33 deletions
This file was deleted.

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.org/AdvancedVI.jl/stable/)
22
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://turinglang.org/AdvancedVI.jl/dev/)
33
[![Build Status](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml?query=branch%3Amaster)
4-
[![JuliaNightly](https://github.com/TuringLang/AdvancedVI.jl/workflows/JuliaNightly/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions?query=workflow%3AJuliaNightly+branch%3Amaster)
54
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)
65

76
# AdvancedVI.jl

ext/AdvancedVIEnzymeExt.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,21 @@ else
1111
using ..AdvancedVI: ADTypes, DiffResults
1212
end
1313

14-
# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
1514
function AdvancedVI.value_and_gradient!(
16-
ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
15+
ad::ADTypes.AutoEnzyme,
16+
f,
17+
θ::AbstractVector{T},
18+
out::DiffResults.MutableDiffResult,
1719
) where {T<:Real}
18-
y = f(θ)
19-
DiffResults.value!(out, y)
2020
∇θ = DiffResults.gradient(out)
2121
fill!(∇θ, zero(T))
22-
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
22+
_, y = Enzyme.autodiff(
23+
Enzyme.ReverseWithPrimal,
24+
f,
25+
Enzyme.Active,
26+
Enzyme.Duplicated(θ, ∇θ),
27+
)
28+
DiffResults.value!(out, y)
2329
return out
2430
end
2531

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
44
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
55
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
66
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
7+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
78
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
910
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -26,6 +27,7 @@ ADTypes = "0.2.1, 1"
2627
Bijectors = "0.13"
2728
Distributions = "0.25.100"
2829
DistributionsAD = "0.6.45"
30+
Enzyme = "0.12"
2931
FillArrays = "1.6.1"
3032
ForwardDiff = "0.10.36"
3133
Functors = "0.4.5"

test/interface/ad.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ using Test
33

44
@testset "ad" begin
55
@testset "$(adname)" for (adname, adsymbol) Dict(
6-
:ForwardDiff => AutoForwardDiff(),
7-
:ReverseDiff => AutoReverseDiff(),
8-
:Zygote => AutoZygote(),
9-
# :Enzyme => AutoEnzyme() # Currently not tested against
10-
)
6+
:ForwardDiff => AutoForwardDiff(),
7+
:ReverseDiff => AutoReverseDiff(),
8+
:Zygote => AutoZygote(),
9+
:Enzyme => AutoEnzyme(),
10+
)
1111
D = 10
1212
A = randn(D, D)
1313
λ = randn(D)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using DistributionsAD
1818
using LogDensityProblems
1919
using Optimisers
2020
using ADTypes
21-
using ForwardDiff, ReverseDiff, Zygote
21+
using Enzyme, ForwardDiff, ReverseDiff, Zygote
2222

2323
using AdvancedVI
2424

0 commit comments

Comments
 (0)