Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ Zygote also provides a set of helpful utilities. These are all "user-level" tool
in other words you could have written them easily yourself, but they live in
Zygote for convenience.

See `ChainRules.ignore_derivatives` if you want to exclude some of your code from the
gradient calculation. This replaces previous Zygote-specific `ignore` and `dropgrad`
functionality.

```@docs
Zygote.withgradient
Zygote.withjacobian
Zygote.@showgrad
Zygote.hook
Zygote.dropgrad
Zygote.Buffer
Zygote.forwarddiff
Zygote.ignore
Zygote.checkpointed
```

Expand Down
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export rrule_via_ad

const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}

include("deprecated.jl")
include("tools/buffer.jl")
include("tools/builtins.jl")

Expand Down
51 changes: 51 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
dropgrad(x) -> x

Drop the gradient of `x`.

julia> gradient(2, 3) do a, b
dropgrad(a)*b
end
(nothing, 2)
"""
function dropgrad end

@adjoint dropgrad(x) = dropgrad(x), _ -> nothing

Base.@deprecate dropgrad(x) ChainRulesCore.ignore_derivatives(x)


"""
ignore() do
...
end

Tell Zygote to ignore a block of code. Everything inside the `do` block will run
on the forward pass as normal, but Zygote won't try to differentiate it at all.
This can be useful for e.g. code that does logging of the forward pass.

Obviously, you run the risk of incorrect gradients if you use this incorrectly.
"""
function ignore end

@adjoint ignore(f) = ignore(f), _ -> nothing

Base.@deprecate ignore(f) ChainRulesCore.ignore_derivatives(f)

"""
@ignore (...)

Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
Example:

```julia-repl
julia> f(x) = (y = Zygote.@ignore x; x * y);
julia> f'(1)
1
```
"""
macro ignore(ex)
return :(Zygote.ignore() do
$(esc(ex))
end)
end
45 changes: 0 additions & 45 deletions src/lib/utils.jl
Original file line number Diff line number Diff line change
@@ -1,48 +1,3 @@
"""
dropgrad(x) -> x

Drop the gradient of `x`.

julia> gradient(2, 3) do a, b
dropgrad(a)*b
end
(nothing, 2)
"""
dropgrad(x) = x
@adjoint dropgrad(x) = dropgrad(x), _ -> nothing

"""
ignore() do
...
end

Tell Zygote to ignore a block of code. Everything inside the `do` block will run
on the forward pass as normal, but Zygote won't try to differentiate it at all.
This can be useful for e.g. code that does logging of the forward pass.

Obviously, you run the risk of incorrect gradients if you use this incorrectly.
"""
ignore(f) = f()
@adjoint ignore(f) = ignore(f), _ -> nothing

"""
@ignore (...)

Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
Example:

```julia-repl
julia> f(x) = (y = Zygote.@ignore x; x * y);
julia> f'(1)
1
```
"""
macro ignore(ex)
return :(Zygote.ignore() do
$(esc(ex))
end)
end

"""
hook(x̄ -> ..., x) -> x

Expand Down
10 changes: 10 additions & 0 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@test_deprecated dropgrad(1)
@test_deprecated ignore(1)
@test_deprecated Zygote.@ignore x=1

@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
@test gradient(1) do x
y = Zygote.@ignore x
x * y
end == (1,)
8 changes: 0 additions & 8 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1681,14 +1681,6 @@ end
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)


@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
@test gradient(1) do x
y = Zygote.@ignore x
x * y
end == (1,)
end

@testset "fastmath" begin
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ using CUDA: has_cuda
@warn "CUDA not found - Skipping CUDA Tests"
end

@testset "deprecated.jl" begin
include("deprecated.jl")
end

@testset "Interface" begin
include("interface.jl")
end
Expand Down