diff --git a/docs/src/api.md b/docs/src/api.md index 0b8351eb3..604718b0e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -93,9 +93,10 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au | Exported symbol | Documentation | Description | |:----------------- |:------------------------------------ |:---------------------- | +| `AutoEnzyme` | [`ADTypes.AutoEnzyme`](@extref) | Enzyme.jl backend | | `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend | -| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | | `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend | +| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | ### Debugging diff --git a/src/Turing.jl b/src/Turing.jl index 1ff231017..a5d0a543c 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -23,7 +23,7 @@ using Printf: Printf using Random: Random using LinearAlgebra: I -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake +using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake, AutoEnzyme const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff() @@ -123,6 +123,7 @@ export AutoForwardDiff, AutoReverseDiff, AutoMooncake, + AutoEnzyme, # Debugging - Turing setprogress!, # Distributions diff --git a/test/ad.jl b/test/ad.jl index 2f645fab5..bb787425d 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -20,11 +20,19 @@ if INCLUDE_MOONCAKE using Mooncake: Mooncake end +const INCLUDE_ENZYME = !IS_PRERELEASE + +if INCLUDE_ENZYME + import Pkg + Pkg.add("Enzyme") + using Enzyme: Enzyme +end + """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) """A dictionary mapping ADTypes to the element types they use.""" -eltypes_by_adtype = Dict( +eltypes_by_adtype = Dict{Type,Tuple}( AutoForwardDiff => (ForwardDiff.Dual,), AutoReverseDiff => ( ReverseDiff.TrackedArray, @@ -39,6 +47,9 @@ eltypes_by_adtype = Dict( if INCLUDE_MOONCAKE eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,) end +if INCLUDE_ENZYME + eltypes_by_adtype[AutoEnzyme] = () +end """ AbstractWrongADBackendError @@ -193,6 +204,22 @@ ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)] if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end +if INCLUDE_ENZYME + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation=Enzyme.Const, + ), + ) + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +end # Check that ADTypeCheckContext itself works as expected. @testset "ADTypeCheckContext" begin