Skip to content

Commit 63c7908

Browse files
committed
Test against Enzyme
1 parent d0510b1 commit 63c7908

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

docs/src/api.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au
9393

9494
| Exported symbol | Documentation | Description |
9595
|:----------------- |:------------------------------------ |:---------------------- |
96+
| `AutoEnzyme` | [`ADTypes.AutoEnzyme`](@extref) | Enzyme.jl backend |
9697
| `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend |
97-
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
9898
| `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend |
99+
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
99100

100101
### Debugging
101102

src/Turing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Printf: Printf
2323
using Random: Random
2424
using LinearAlgebra: I
2525

26-
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake
26+
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake, AutoEnzyme
2727

2828
const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff()
2929

@@ -123,6 +123,7 @@ export
123123
AutoForwardDiff,
124124
AutoReverseDiff,
125125
AutoMooncake,
126+
AutoEnzyme,
126127
# Debugging - Turing
127128
setprogress!,
128129
# Distributions

test/ad.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ if INCLUDE_MOONCAKE
2020
using Mooncake: Mooncake
2121
end
2222

23+
const INCLUDE_ENZYME = !IS_PRERELEASE
24+
25+
if INCLUDE_ENZYME
26+
import Pkg
27+
Pkg.add("Enzyme")
28+
using Enzyme: Enzyme
29+
end
30+
2331
"""Element types that are always valid for a VarInfo regardless of ADType."""
2432
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
2533

@@ -193,6 +201,10 @@ ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)]
193201
if INCLUDE_MOONCAKE
194202
push!(ADTYPES, AutoMooncake(; config=nothing))
195203
end
204+
if INCLUDE_ENZYME
205+
push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Forward)))
206+
push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Reverse)))
207+
end
196208

197209
# Check that ADTypeCheckContext itself works as expected.
198210
@testset "ADTypeCheckContext" begin

0 commit comments

Comments
 (0)