Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <[email protected]>"]
version = "1.5.1"
version = "1.6.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
35 changes: 24 additions & 11 deletions ext/DynamicExpressionsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
module DynamicExpressionsZygoteExt

import Zygote: gradient
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
using Zygote: gradient
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGradient

function _zygote_gradient(op::F, ::Val{1}) where {F}
function (x)
out = gradient(op, x)[1]
return out === nothing ? zero(x) : out
end
return ZygoteGradient{F,1,1}(op)
end
function _zygote_gradient(op::F, ::Val{2}) where {F}
function (x, y)
(∂x, ∂y) = gradient(op, x, y)
return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y)
end
function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side}
# side should be either nothing (for both), 1, or 2
@assert side === nothing || side in (1, 2)
return ZygoteGradient{F,2,side}(op)
end

function (g::ZygoteGradient{F,1,1})(x) where {F}
out = only(gradient(g.op, x))
return out === nothing ? zero(x) : out
end
function (g::ZygoteGradient{F,2,nothing})(x, y) where {F}
(∂x, ∂y) = gradient(g.op, x, y)
return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y)
end
function (g::ZygoteGradient{F,2,1})(x, y) where {F}
∂x = only(gradient(Base.Fix2(g.op, y), x))
return ∂x === nothing ? zero(x) : ∂x
end
function (g::ZygoteGradient{F,2,2})(x, y) where {F}
∂y = only(gradient(Base.Fix1(g.op, x), y))
return ∂y === nothing ? zero(y) : ∂y
end

end
18 changes: 18 additions & 0 deletions src/ExtensionInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ function symbolic_to_node(args...; kws...)
return error("Please load the `SymbolicUtils` package to use `symbolic_to_node`.")
end

struct ZygoteGradient{F,degree,arg} <: Function
op::F
end

function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg}
print(io, "∂")
if degree == 2
if arg == 1
print(io, "₁")
elseif arg == 2
print(io, "₂")
end
end
print(io, g.op)
return nothing
end
Base.show(io::IO, ::MIME"text/plain", g::ZygoteGradient) = show(io, g)

function _zygote_gradient(args...)
return error("Please load the Zygote.jl package.")
end
Expand Down
54 changes: 54 additions & 0 deletions test/test_zygote_gradient_wrapper.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
@testitem "ZygoteGradient string representation" begin
using DynamicExpressions
using DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
using Zygote

# Test unary gradient
f(x) = x^2
@test repr(_zygote_gradient(f, Val(1))) == "∂f"

# Test binary gradient (both partials)
g(x, y) = x * y
@test repr(_zygote_gradient(g, Val(2))) == "∂g"

# Test binary gradient (first partial)
@test repr(_zygote_gradient(g, Val(2), Val(1))) == "∂₁g"

# Test binary gradient (second partial)
@test repr(_zygote_gradient(g, Val(2), Val(2))) == "∂₂g"

# Test with standard operators
@test repr(_zygote_gradient(+, Val(2))) == "∂+"
@test repr(_zygote_gradient(*, Val(2), Val(1))) == "∂₁*"
@test repr(_zygote_gradient(*, Val(2), Val(2))) == "∂₂*"

first_partial = _zygote_gradient(log, Val(2), Val(1))
nested = _zygote_gradient(first_partial, Val(1))
@test repr(nested) == "∂∂₁log"

# Also should work with text/plain
@test repr("text/plain", nested) == "∂∂₁log"
end

@testitem "ZygoteGradient evaluation" begin
using DynamicExpressions
using DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
using Zygote

x = 2.0
y = 3.0

# Test unary gradient
f(x) = x^2
@test (_zygote_gradient(f, Val(1)))(x) == 4.0

# Test binary gradient (both partials)
g(x, y) = x * y
@test (_zygote_gradient(g, Val(2)))(x, y) == (3.0, 2.0)

# Test binary gradient (first partial)
@test (_zygote_gradient(g, Val(2), Val(1)))(x, y) == 3.0

# Test second partial
@test (_zygote_gradient(g, Val(2), Val(2)))(x, y) == 2.0
end
1 change: 1 addition & 0 deletions test/unittest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@ include("test_node_interface.jl")
include("test_expression_math.jl")
include("test_structured_expression.jl")
include("test_readonlynode.jl")
include("test_zygote_gradient_wrapper.jl")
Loading