Skip to content

Commit a4d0ad4

Browse files
authored
Merge pull request #1245 from mzgubic/mz/deprecate
deprecate `dropgrad` and `ignore`
2 parents 3239330 + 4538914 commit a4d0ad4

File tree

7 files changed

+70
-55
lines changed

7 files changed

+70
-55
lines changed

docs/src/utils.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@ Zygote also provides a set of helpful utilities. These are all "user-level" tool
1313
in other words you could have written them easily yourself, but they live in
1414
Zygote for convenience.
1515

16+
See `ChainRules.ignore_derivatives` if you want to exclude some of your code from the
17+
gradient calculation. This replaces previous Zygote-specific `ignore` and `dropgrad`
18+
functionality.
19+
1620
```@docs
1721
Zygote.withgradient
1822
Zygote.withjacobian
1923
Zygote.@showgrad
2024
Zygote.hook
21-
Zygote.dropgrad
2225
Zygote.Buffer
2326
Zygote.forwarddiff
24-
Zygote.ignore
2527
Zygote.checkpointed
2628
```
2729

src/Zygote.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ export rrule_via_ad
1818

1919
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}
2020

21+
include("deprecated.jl")
2122
include("tools/buffer.jl")
2223
include("tools/builtins.jl")
2324

src/deprecated.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
dropgrad(x) -> x
3+
4+
Drop the gradient of `x`.
5+
6+
julia> gradient(2, 3) do a, b
7+
dropgrad(a)*b
8+
end
9+
(nothing, 2)
10+
"""
11+
function dropgrad end
12+
13+
@adjoint dropgrad(x) = dropgrad(x), _ -> nothing
14+
15+
Base.@deprecate dropgrad(x) ChainRulesCore.ignore_derivatives(x)
16+
17+
18+
"""
19+
ignore() do
20+
...
21+
end
22+
23+
Tell Zygote to ignore a block of code. Everything inside the `do` block will run
24+
on the forward pass as normal, but Zygote won't try to differentiate it at all.
25+
This can be useful for e.g. code that does logging of the forward pass.
26+
27+
Obviously, you run the risk of incorrect gradients if you use this incorrectly.
28+
"""
29+
function ignore end
30+
31+
@adjoint ignore(f) = ignore(f), _ -> nothing
32+
33+
Base.@deprecate ignore(f) ChainRulesCore.ignore_derivatives(f)
34+
35+
"""
36+
@ignore (...)
37+
38+
Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
39+
Example:
40+
41+
```julia-repl
42+
julia> f(x) = (y = Zygote.@ignore x; x * y);
43+
julia> f'(1)
44+
1
45+
```
46+
"""
47+
macro ignore(ex)
48+
return :(Zygote.ignore() do
49+
$(esc(ex))
50+
end)
51+
end

src/lib/utils.jl

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,3 @@
1-
"""
2-
dropgrad(x) -> x
3-
4-
Drop the gradient of `x`.
5-
6-
julia> gradient(2, 3) do a, b
7-
dropgrad(a)*b
8-
end
9-
(nothing, 2)
10-
"""
11-
dropgrad(x) = x
12-
@adjoint dropgrad(x) = dropgrad(x), _ -> nothing
13-
14-
"""
15-
ignore() do
16-
...
17-
end
18-
19-
Tell Zygote to ignore a block of code. Everything inside the `do` block will run
20-
on the forward pass as normal, but Zygote won't try to differentiate it at all.
21-
This can be useful for e.g. code that does logging of the forward pass.
22-
23-
Obviously, you run the risk of incorrect gradients if you use this incorrectly.
24-
"""
25-
ignore(f) = f()
26-
@adjoint ignore(f) = ignore(f), _ -> nothing
27-
28-
"""
29-
@ignore (...)
30-
31-
Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
32-
Example:
33-
34-
```julia-repl
35-
julia> f(x) = (y = Zygote.@ignore x; x * y);
36-
julia> f'(1)
37-
1
38-
```
39-
"""
40-
macro ignore(ex)
41-
return :(Zygote.ignore() do
42-
$(esc(ex))
43-
end)
44-
end
45-
461
"""
472
hook(x̄ -> ..., x) -> x
483

test/deprecated.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@test_deprecated dropgrad(1)
2+
@test_deprecated ignore(1)
3+
@test_deprecated Zygote.@ignore x=1
4+
5+
@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
6+
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
7+
@test gradient(1) do x
8+
y = Zygote.@ignore x
9+
x * y
10+
end == (1,)

test/gradcheck.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,14 +1681,6 @@ end
16811681
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
16821682
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
16831683
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)
1684-
1685-
1686-
@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
1687-
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
1688-
@test gradient(1) do x
1689-
y = Zygote.@ignore x
1690-
x * y
1691-
end == (1,)
16921684
end
16931685

16941686
@testset "fastmath" begin

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ using CUDA: has_cuda
1414
@warn "CUDA not found - Skipping CUDA Tests"
1515
end
1616

17+
@testset "deprecated.jl" begin
18+
include("deprecated.jl")
19+
end
20+
1721
@testset "Interface" begin
1822
include("interface.jl")
1923
end

0 commit comments

Comments
 (0)