diff --git a/docs/src/utils.md b/docs/src/utils.md index b7e779185..25b5954e4 100644 --- a/docs/src/utils.md +++ b/docs/src/utils.md @@ -13,15 +13,17 @@ Zygote also provides a set of helpful utilities. These are all "user-level" tool in other words you could have written them easily yourself, but they live in Zygote for convenience. +See `ChainRules.ignore_derivatives` if you want to exclude some of your code from the +gradient calculation. This replaces previous Zygote-specific `ignore` and `dropgrad` +functionality. + ```@docs Zygote.withgradient Zygote.withjacobian Zygote.@showgrad Zygote.hook -Zygote.dropgrad Zygote.Buffer Zygote.forwarddiff -Zygote.ignore Zygote.checkpointed ``` diff --git a/src/Zygote.jl b/src/Zygote.jl index a42dd38c1..b1ca50aa9 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -18,6 +18,7 @@ export rrule_via_ad const Numeric{T<:Number} = Union{T, AbstractArray{<:T}} +include("deprecated.jl") include("tools/buffer.jl") include("tools/builtins.jl") diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 000000000..9bc808511 --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,51 @@ +""" + dropgrad(x) -> x + +Drop the gradient of `x`. + + julia> gradient(2, 3) do a, b + dropgrad(a)*b + end + (nothing, 2) +""" +function dropgrad end + +@adjoint dropgrad(x) = dropgrad(x), _ -> nothing + +Base.@deprecate dropgrad(x) ChainRulesCore.ignore_derivatives(x) + + +""" + ignore() do + ... + end + +Tell Zygote to ignore a block of code. Everything inside the `do` block will run +on the forward pass as normal, but Zygote won't try to differentiate it at all. +This can be useful for e.g. code that does logging of the forward pass. + +Obviously, you run the risk of incorrect gradients if you use this incorrectly. +""" +function ignore end + +@adjoint ignore(f) = ignore(f), _ -> nothing + +Base.@deprecate ignore(f) ChainRulesCore.ignore_derivatives(f) + +""" + @ignore (...) + +Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`. +Example: + +```julia-repl +julia> f(x) = (y = Zygote.@ignore x; x * y); +julia> f'(1) +1 +``` +""" +macro ignore(ex) + return :(Zygote.ignore() do + $(esc(ex)) + end) +end diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 86e6fff8c..72c60a961 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -1,48 +1,3 @@ -""" - dropgrad(x) -> x - -Drop the gradient of `x`. - - julia> gradient(2, 3) do a, b - dropgrad(a)*b - end - (nothing, 2) -""" -dropgrad(x) = x -@adjoint dropgrad(x) = dropgrad(x), _ -> nothing - -""" - ignore() do - ... - end - -Tell Zygote to ignore a block of code. Everything inside the `do` block will run -on the forward pass as normal, but Zygote won't try to differentiate it at all. -This can be useful for e.g. code that does logging of the forward pass. - -Obviously, you run the risk of incorrect gradients if you use this incorrectly. -""" -ignore(f) = f() -@adjoint ignore(f) = ignore(f), _ -> nothing - -""" - @ignore (...) - -Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`. -Example: - -```julia-repl -julia> f(x) = (y = Zygote.@ignore x; x * y); -julia> f'(1) -1 -``` -""" -macro ignore(ex) - return :(Zygote.ignore() do - $(esc(ex)) - end) -end - """ hook(x̄ -> ..., x) -> x diff --git a/test/deprecated.jl b/test/deprecated.jl new file mode 100644 index 000000000..ffc4994c7 --- /dev/null +++ b/test/deprecated.jl @@ -0,0 +1,10 @@ +@test_deprecated dropgrad(1) +@test_deprecated ignore(1) +@test_deprecated Zygote.@ignore x=1 + +@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,) +@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,) +@test gradient(1) do x + y = Zygote.@ignore x + x * y +end == (1,) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index ac0dd28bf..268c1734e 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1681,14 +1681,6 @@ end @test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,) @test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,) @test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,) - - - @test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,) - @test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,) - @test gradient(1) do x - y = Zygote.@ignore x - x * y - end == (1,) end @testset "fastmath" begin diff --git a/test/runtests.jl b/test/runtests.jl index fe5590efd..565ad182f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,10 @@ using CUDA: has_cuda @warn "CUDA not found - Skipping CUDA Tests" end + @testset "deprecated.jl" begin + include("deprecated.jl") + end + @testset "Interface" begin include("interface.jl") end