Skip to content

Commit 7e0a1b4

Browse files
authored
Merge pull request #111 from SymbolicML/clean-up-zygote-gradients
Prettier printing for gradient operators
2 parents c73a705 + e27214a commit 7e0a1b4

File tree

5 files changed

+98
-12
lines changed

5 files changed

+98
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "1.5.1"
4+
version = "1.6.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

ext/DynamicExpressionsZygoteExt.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
11
module DynamicExpressionsZygoteExt
22

3-
import Zygote: gradient
4-
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
3+
using Zygote: gradient
4+
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGradient
55

66
function _zygote_gradient(op::F, ::Val{1}) where {F}
7-
function (x)
8-
out = gradient(op, x)[1]
9-
return out === nothing ? zero(x) : out
10-
end
7+
return ZygoteGradient{F,1,1}(op)
118
end
12-
function _zygote_gradient(op::F, ::Val{2}) where {F}
13-
function (x, y)
14-
(∂x, ∂y) = gradient(op, x, y)
15-
return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y)
16-
end
9+
function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side}
10+
# side should be either nothing (for both), 1, or 2
11+
@assert side === nothing || side in (1, 2)
12+
return ZygoteGradient{F,2,side}(op)
13+
end
14+
15+
function (g::ZygoteGradient{F,1,1})(x) where {F}
16+
out = only(gradient(g.op, x))
17+
return out === nothing ? zero(x) : out
18+
end
19+
function (g::ZygoteGradient{F,2,nothing})(x, y) where {F}
20+
(∂x, ∂y) = gradient(g.op, x, y)
21+
return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y)
22+
end
23+
function (g::ZygoteGradient{F,2,1})(x, y) where {F}
24+
∂x = only(gradient(Base.Fix2(g.op, y), x))
25+
return ∂x === nothing ? zero(x) : ∂x
26+
end
27+
function (g::ZygoteGradient{F,2,2})(x, y) where {F}
28+
∂y = only(gradient(Base.Fix1(g.op, x), y))
29+
return ∂y === nothing ? zero(y) : ∂y
1730
end
1831

1932
end

src/ExtensionInterface.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,24 @@ function symbolic_to_node(args...; kws...)
77
return error("Please load the `SymbolicUtils` package to use `symbolic_to_node`.")
88
end
99

10+
struct ZygoteGradient{F,degree,arg} <: Function
11+
op::F
12+
end
13+
14+
function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg}
15+
print(io, "")
16+
if degree == 2
17+
if arg == 1
18+
print(io, "")
19+
elseif arg == 2
20+
print(io, "")
21+
end
22+
end
23+
print(io, g.op)
24+
return nothing
25+
end
26+
Base.show(io::IO, ::MIME"text/plain", g::ZygoteGradient) = show(io, g)
27+
1028
function _zygote_gradient(args...)
1129
return error("Please load the Zygote.jl package.")
1230
end
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
@testitem "ZygoteGradient string representation" begin
2+
using DynamicExpressions
3+
using DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
4+
using Zygote
5+
6+
# Test unary gradient
7+
f(x) = x^2
8+
@test repr(_zygote_gradient(f, Val(1))) == "∂f"
9+
10+
# Test binary gradient (both partials)
11+
g(x, y) = x * y
12+
@test repr(_zygote_gradient(g, Val(2))) == "∂g"
13+
14+
# Test binary gradient (first partial)
15+
@test repr(_zygote_gradient(g, Val(2), Val(1))) == "∂₁g"
16+
17+
# Test binary gradient (second partial)
18+
@test repr(_zygote_gradient(g, Val(2), Val(2))) == "∂₂g"
19+
20+
# Test with standard operators
21+
@test repr(_zygote_gradient(+, Val(2))) == "∂+"
22+
@test repr(_zygote_gradient(*, Val(2), Val(1))) == "∂₁*"
23+
@test repr(_zygote_gradient(*, Val(2), Val(2))) == "∂₂*"
24+
25+
first_partial = _zygote_gradient(log, Val(2), Val(1))
26+
nested = _zygote_gradient(first_partial, Val(1))
27+
@test repr(nested) == "∂∂₁log"
28+
29+
# Also should work with text/plain
30+
@test repr("text/plain", nested) == "∂∂₁log"
31+
end
32+
33+
@testitem "ZygoteGradient evaluation" begin
34+
using DynamicExpressions
35+
using DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
36+
using Zygote
37+
38+
x = 2.0
39+
y = 3.0
40+
41+
# Test unary gradient
42+
f(x) = x^2
43+
@test (_zygote_gradient(f, Val(1)))(x) == 4.0
44+
45+
# Test binary gradient (both partials)
46+
g(x, y) = x * y
47+
@test (_zygote_gradient(g, Val(2)))(x, y) == (3.0, 2.0)
48+
49+
# Test binary gradient (first partial)
50+
@test (_zygote_gradient(g, Val(2), Val(1)))(x, y) == 3.0
51+
52+
# Test second partial
53+
@test (_zygote_gradient(g, Val(2), Val(2)))(x, y) == 2.0
54+
end

test/unittest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,4 @@ include("test_node_interface.jl")
130130
include("test_expression_math.jl")
131131
include("test_structured_expression.jl")
132132
include("test_readonlynode.jl")
133+
include("test_zygote_gradient_wrapper.jl")

0 commit comments

Comments
 (0)