Skip to content

Commit be50700

Browse files
committed
Test against Enzyme
1 parent 296f654 commit be50700

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

docs/src/api.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au
9393

9494
| Exported symbol | Documentation | Description |
9595
|:----------------- |:------------------------------------ |:---------------------- |
96+
| `AutoEnzyme` | [`ADTypes.AutoEnzyme`](@extref) | Enzyme.jl backend |
9697
| `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend |
97-
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
9898
| `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend |
99+
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
99100

100101
### Debugging
101102

src/Turing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Printf: Printf
2323
using Random: Random
2424
using LinearAlgebra: I
2525

26-
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake
26+
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake, AutoEnzyme
2727

2828
const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff()
2929

@@ -121,6 +121,7 @@ export
121121
AutoForwardDiff,
122122
AutoReverseDiff,
123123
AutoMooncake,
124+
AutoEnzyme,
124125
# Debugging - Turing
125126
setprogress!,
126127
# Distributions

test/ad.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ if INCLUDE_MOONCAKE
2020
using Mooncake: Mooncake
2121
end
2222

23+
const INCLUDE_ENZYME = !IS_PRERELEASE
24+
25+
if INCLUDE_ENZYME
26+
import Pkg
27+
Pkg.add("Enzyme")
28+
using Enzyme: Enzyme
29+
end
30+
2331
"""Element types that are always valid for a VarInfo regardless of ADType."""
2432
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
2533

@@ -191,6 +199,10 @@ ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)]
191199
if INCLUDE_MOONCAKE
192200
push!(ADTYPES, AutoMooncake(; config=nothing))
193201
end
202+
if INCLUDE_ENZYME
203+
push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Forward)))
204+
push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Reverse)))
205+
end
194206

195207
# Check that ADTypeCheckContext itself works as expected.
196208
@testset "ADTypeCheckContext" begin

0 commit comments

Comments
 (0)