diff --git a/Project.toml b/Project.toml index 5f11cba3f..e0062c9ea 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index b733d810c..3857bfccd 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -43,13 +43,17 @@ chosen_combinations = [ ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), + ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), ("Multivariate 10k", multivariate10k, :typed, :mooncake, true), ("Dynamic", Models.dynamic(), :typed, :mooncake, true), + ("Dynamic", Models.dynamic(), :typed, :enzyme, true), ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), + ("Submodel", Models.parent(randn(rng)), :typed, :enzyme, true), ("LDA", lda_instance, :typed, :reversediff, true), + ("LDA", lda_instance, :typed, :enzyme, true), ] # Time running a model-like function that does not use DynamicPPL, as a reference point. diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 8c5032ace..195884ee8 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -13,7 +13,7 @@ using StableRNGs: StableRNG include("./Models.jl") using .Models: Models - +import Enzyme export Models, make_suite, model_dimension """ @@ -37,6 +37,7 @@ const SYMBOL_TO_BACKEND = Dict( :reversediff => ADTypes.AutoReverseDiff(; compile=false), :reversediff_compiled => ADTypes.AutoReverseDiff(; compile=true), :mooncake => ADTypes.AutoMooncake(; config=nothing), + :enzyme => ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const), ) to_backend(x) = error("Unknown backend: $x")