diff --git a/HISTORY.md b/HISTORY.md index 6251974075..62ca1d350c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -6,6 +6,12 @@ 0.37 removes the old Gibbs constructors deprecated in 0.36. +### Remove Zygote support + +Zygote is no longer officially supported as an automatic differentiation backend, and `AutoZygote` is no longer exported. You can continue to use Zygote by importing `AutoZygote` from ADTypes and it may well continue to work, but it is no longer tested and no effort will be expended to fix it if something breaks. + +[Mooncake](https://github.com/compintell/Mooncake.jl/) is the recommended replacement for Zygote. + ### DynamicPPL 0.35 Turing.jl v0.37 uses DynamicPPL v0.35, which brings with it several breaking changes: diff --git a/docs/src/api.md b/docs/src/api.md index afedf59bb7..3066a7fad9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -88,7 +88,6 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au |:----------------- |:------------------------------------ |:---------------------- | | `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend | | `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | -| `AutoZygote` | [`ADTypes.AutoZygote`](@extref) | Zygote.jl backend | | `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend | ### Debugging diff --git a/src/Turing.jl b/src/Turing.jl index abba580a27..d8fc09fb9d 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -106,7 +106,6 @@ export @model, # modelling externalsampler, AutoForwardDiff, # ADTypes AutoReverseDiff, - AutoZygote, AutoMooncake, setprogress!, # debugging Flat, diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index c04c7e862b..cfa064c651 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake +using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake using AdvancedPS: AdvancedPS @@ -20,7 +20,6 @@ include("container.jl") export @model, @varname, AutoForwardDiff, - AutoZygote, AutoReverseDiff, AutoMooncake, @logprob_str, diff --git a/test/Project.toml b/test/Project.toml index 96681f0491..489923767a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "5" @@ -75,5 +74,4 @@ StableRNGs = "1" StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" -Zygote = "0.5.4, 0.6" julia = "1.10" diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index ee544e7ce3..309276407a 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -8,7 +8,6 @@ using Mooncake: Mooncake using Test: Test using Turing: Turing using Turing: DynamicPPL -using Zygote: Zygote export ADTypeCheckContext, adbackends @@ -31,9 +30,6 @@ const eltypes_by_adtype = Dict( ReverseDiff.TrackedVector, ), Turing.AutoMooncake => (Mooncake.CoDual,), - # Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the - # two by element type. However, we have other checks for Zygote, see check_adtype. - Turing.AutoZygote => (Zygote.Dual,), ) """ @@ -90,7 +86,6 @@ For instance, evaluating a model with would throw an error if within the model a type associated with e.g. ReverseDiff was encountered. -As a current short-coming, this context can not distinguish between ForwardDiff and Zygote. """ struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext @@ -134,21 +129,9 @@ end Check that the element types in `vi` are compatible with the ADType of `context`. -When Zygote is being used, we also more explicitly check that `adtype(context)` is -`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't -discriminate between the two based on element type alone. This function will still fail to -catch cases where Zygote is supposed to be used, but ForwardDiff is used instead. - -Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or -`WrongADBackendError` if Zygote is used unexpectedly. +Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. """ function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) - Zygote.hook(vi) do _ - if !(adtype(context) <: Turing.AutoZygote) - throw(WrongADBackendError(Turing.AutoZygote, adtype(context))) - end - end - valids = valid_eltypes(context) for val in vi[:] valtype = typeof(val) diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl index bf9f2b9b8d..243a80b881 100644 --- a/test/test_utils/test_utils.jl +++ b/test/test_utils/test_utils.jl @@ -7,7 +7,6 @@ using ReverseDiff: ReverseDiff using Test: @test, @testset, @test_throws using Turing: Turing using Turing: DynamicPPL -using Zygote: Zygote # Check that the ADTypeCheckContext works as expected. @testset "ADTypeCheckContext" begin @@ -16,20 +15,12 @@ using Zygote: Zygote adtypes = ( Turing.AutoForwardDiff(), Turing.AutoReverseDiff(), - Turing.AutoZygote(), # TODO: Mooncake # Turing.AutoMooncake(config=nothing), ) for actual_adtype in adtypes sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) for expected_adtype in adtypes - if ( - actual_adtype == Turing.AutoForwardDiff() && - expected_adtype == Turing.AutoZygote() - ) - # TODO(mhauru) We are currently unable to check this case. - continue - end contextualised_tm = DynamicPPL.contextualize( tm, ADTypeCheckContext(expected_adtype, tm.context) )