Skip to content

Commit 83902e3

Browse files
committed
feat: make diff compatible with n-arity
1 parent 7b04a15 commit 83902e3

File tree

5 files changed

+80
-145
lines changed

5 files changed

+80
-145
lines changed

ext/DynamicExpressionsZygoteExt.jl

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,17 @@ module DynamicExpressionsZygoteExt
33
using Zygote: gradient
44
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGradient
55

6-
function _zygote_gradient(op::F, ::Val{1}) where {F}
7-
return ZygoteGradient{F,1,1}(op)
8-
end
9-
function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side}
10-
# side should be either nothing (for both), 1, or 2
11-
@assert side === nothing || side in (1, 2)
12-
return ZygoteGradient{F,2,side}(op)
6+
function _zygote_gradient(op::F, ::Val{degree}) where {F,degree}
7+
return ZygoteGradient{F,degree}(op)
138
end
149

15-
function (g::ZygoteGradient{F,1,1})(x) where {F}
10+
function (g::ZygoteGradient{F,1})(x) where {F}
1611
out = only(gradient(g.op, x))
1712
return out === nothing ? zero(x) : out
1813
end
19-
function (g::ZygoteGradient{F,2,nothing})(x, y) where {F}
20-
(∂x, ∂y) = gradient(g.op, x, y)
21-
return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y)
22-
end
23-
function (g::ZygoteGradient{F,2,1})(x, y) where {F}
24-
∂x = only(gradient(Base.Fix2(g.op, y), x))
25-
return ∂x === nothing ? zero(x) : ∂x
26-
end
27-
function (g::ZygoteGradient{F,2,2})(x, y) where {F}
28-
∂y = only(gradient(Base.Fix1(g.op, x), y))
29-
return ∂y === nothing ? zero(y) : ∂y
14+
function (g::ZygoteGradient{F,degree})(args::Vararg{Any,degree}) where {F,degree}
15+
partials = gradient(g.op, args...)
16+
return ntuple(i -> @something(partials[i], zero(args[i])), Val(degree))
3017
end
3118

3219
end

src/EvaluateDerivative.jl

Lines changed: 67 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module EvaluateDerivativeModule
22

3-
import ..NodeModule: AbstractExpressionNode, constructorof
3+
import ..NodeModule: AbstractExpressionNode, constructorof, get_children
44
import ..OperatorEnumModule: OperatorEnum
55
import ..UtilsModule: fill_similar, ResultOk2
66
import ..ValueInterfaceModule: is_valid_array
@@ -66,54 +66,18 @@ function eval_diff_tree_array(
6666
end
6767

6868
@generated function _eval_diff_tree_array(
69-
tree::AbstractExpressionNode{T},
69+
tree::AbstractExpressionNode{T,D},
7070
cX::AbstractMatrix{T},
7171
operators::OperatorEnum,
7272
direction::Integer,
73-
)::ResultOk2 where {T<:Number}
74-
nuna = get_nuna(operators)
75-
nbin = get_nbin(operators)
76-
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
77-
quote
78-
diff_deg1_eval(tree, cX, operators.unaops[op_idx], operators, direction)
79-
end
80-
else
81-
quote
82-
Base.Cartesian.@nif(
83-
$nuna,
84-
i -> i == op_idx,
85-
i ->
86-
diff_deg1_eval(tree, cX, operators.unaops[i], operators, direction)
87-
)
88-
end
89-
end
90-
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
91-
quote
92-
diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction)
93-
end
94-
else
95-
quote
96-
Base.Cartesian.@nif(
97-
$nbin,
98-
i -> i == op_idx,
99-
i ->
100-
diff_deg2_eval(tree, cX, operators.binops[i], operators, direction)
101-
)
102-
end
103-
end
73+
)::ResultOk2 where {T<:Number,D}
10474
quote
105-
result = if tree.degree == 0
106-
diff_deg0_eval(tree, cX, direction)
107-
elseif tree.degree == 1
108-
op_idx = tree.op
109-
$deg1_branch
110-
else
111-
op_idx = tree.op
112-
$deg2_branch
113-
end
114-
!result.ok && return result
115-
return ResultOk2(
116-
result.x, result.dx, is_valid_array(result.x) && is_valid_array(result.dx)
75+
deg = tree.degree
76+
deg == 0 && return diff_deg0_eval(tree, cX, direction)
77+
Base.Cartesian.@nif(
78+
$D,
79+
i -> i == deg,
80+
i -> dispatch_diff_degn_eval(tree, cX, Val(i), operators, direction)
11781
)
11882
end
11983
end
@@ -130,58 +94,71 @@ function diff_deg0_eval(
13094
return ResultOk2(const_part, derivative_part, true)
13195
end
13296

133-
function diff_deg1_eval(
134-
tree::AbstractExpressionNode{T},
135-
cX::AbstractMatrix{T},
136-
op::F,
137-
operators::OperatorEnum,
138-
direction::Integer,
139-
) where {T<:Number,F}
140-
result = _eval_diff_tree_array(tree.l, cX, operators, direction)
141-
!result.ok && return result
142-
143-
# TODO - add type assertions to get better speed:
144-
cumulator = result.x
145-
dcumulator = result.dx
146-
diff_op = _zygote_gradient(op, Val(1))
147-
@inbounds @simd for j in eachindex(cumulator)
148-
x = op(cumulator[j])::T
149-
dx = diff_op(cumulator[j])::T * dcumulator[j]
150-
151-
cumulator[j] = x
152-
dcumulator[j] = dx
97+
@generated function diff_degn_eval(
98+
x_cumulators::NTuple{N}, dx_cumulators::NTuple{N}, op::F, direction::Integer
99+
) where {N,F}
100+
quote
101+
Base.Cartesian.@nexprs($N, i -> begin
102+
x_cumulator_i = x_cumulators[i]
103+
dx_cumulator_i = dx_cumulators[i]
104+
end)
105+
diff_op = _zygote_gradient(op, Val(N))
106+
@inbounds @simd for j in eachindex(x_cumulator_1)
107+
x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j])
108+
Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall(
109+
$N, diff_op, i -> x_cumulator_i[j]
110+
)
111+
dx = Base.Cartesian.@ncall($N, +, i -> grad_i * dx_cumulator_i[j])
112+
x_cumulator_1[j] = x
113+
dx_cumulator_1[j] = dx
114+
end
115+
return ResultOk2(x_cumulator_1, dx_cumulator_1, true)
153116
end
154-
return result
155117
end
156118

157-
function diff_deg2_eval(
158-
tree::AbstractExpressionNode{T},
119+
@generated function dispatch_diff_degn_eval(
120+
tree::AbstractExpressionNode{T,D},
159121
cX::AbstractMatrix{T},
160-
op::F,
161-
operators::OperatorEnum,
122+
::Val{degree},
123+
operators::OperatorEnum{OPS},
162124
direction::Integer,
163-
) where {T<:Number,F}
164-
result_l = _eval_diff_tree_array(tree.l, cX, operators, direction)
165-
!result_l.ok && return result_l
166-
result_r = _eval_diff_tree_array(tree.r, cX, operators, direction)
167-
!result_r.ok && return result_r
168-
169-
ar_l = result_l.x
170-
d_ar_l = result_l.dx
171-
ar_r = result_r.x
172-
d_ar_r = result_r.dx
173-
diff_op = _zygote_gradient(op, Val(2))
174-
175-
@inbounds @simd for j in eachindex(ar_l)
176-
x = op(ar_l[j], ar_r[j])::T
177-
178-
first, second = diff_op(ar_l[j], ar_r[j])::Tuple{T,T}
179-
dx = first * d_ar_l[j] + second * d_ar_r[j]
125+
) where {T<:Number,D,degree,OPS}
126+
nops = length(OPS.types[degree].types)
127+
128+
setup = quote
129+
cs = get_children(tree, Val($degree))
130+
Base.Cartesian.@nexprs(
131+
$degree,
132+
i -> begin
133+
result_i = _eval_diff_tree_array(cs[i], cX, operators, direction)
134+
!result_i.ok && return result_i
135+
end
136+
)
137+
x_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x)
138+
dx_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx)
139+
op_idx = tree.op
140+
end
180141

181-
ar_l[j] = x
182-
d_ar_l[j] = dx
142+
if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN
143+
quote
144+
$setup
145+
diff_degn_eval(
146+
x_cumulators, dx_cumulators, operators[$degree][op_idx], direction
147+
)
148+
end
149+
else
150+
quote
151+
$setup
152+
Base.Cartesian.@nif(
153+
$nops,
154+
i -> i == op_idx,
155+
i -> diff_degn_eval(
156+
x_cumulators, dx_cumulators, operators[$degree][i], direction
157+
)
158+
)
159+
end
183160
end
184-
return result_l
161+
# TODO: Need to add the case for many operators
185162
end
186163

187164
"""

src/ExtensionInterface.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,12 @@ function symbolic_to_node(args...; kws...)
77
return error("Please load the `SymbolicUtils` package to use `symbolic_to_node`.")
88
end
99

10-
struct ZygoteGradient{F,degree,arg} <: Function
10+
struct ZygoteGradient{F,degree} <: Function
1111
op::F
1212
end
1313

14-
function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg}
14+
function Base.show(io::IO, g::ZygoteGradient{F,degree}) where {F,degree}
1515
print(io, "∂")
16-
if degree == 2
17-
if arg == 1
18-
print(io, "₁")
19-
elseif arg == 2
20-
print(io, "₂")
21-
end
22-
end
2316
print(io, g.op)
2417
return nothing
2518
end

src/Node.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ import ..OperatorEnumModule: AbstractOperatorEnum
66
import ..UtilsModule: deprecate_varmap, Undefined
77

88
const DEFAULT_NODE_TYPE = Float32
9+
const DEFAULT_MAX_DEGREE = 2
910

1011
"""
1112
AbstractNode{D}
1213
1314
Abstract type for D-arity trees. Must have the following fields:
1415
15-
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
16-
then `l` needs to be defined as the left child. If 2,
17-
then `r` also needs to be defined as the right child.
16+
- `degree::UInt8`: Degree of the node. This should be a value
17+
between 0 and `D` (which is `DEFAULT_MAX_DEGREE` by default).
1818
- `children`: A collection of D references to children nodes.
1919
2020
# Deprecated fields
@@ -25,7 +25,7 @@ Abstract type for D-arity trees. Must have the following fields:
2525
Don't use `nothing` to represent an undefined value
2626
as it will incur a large performance penalty.
2727
- `r::AbstractNode{D}`: Right child of the current node. Should only
28-
be defined if `degree == 2`.
28+
be defined if `degree >= 2`.
2929
"""
3030
abstract type AbstractNode{D} end
3131

@@ -82,7 +82,7 @@ for N in (:Node, :GraphNode)
8282
## Constructors:
8383
#################
8484
$N{_T,_D}() where {_T,_D} = new{_T,_D::Int}()
85-
$N{_T}() where {_T} = $N{_T,2}()
85+
$N{_T}() where {_T} = $N{_T,DEFAULT_MAX_DEGREE}()
8686
# TODO: Test with this disabled to spot any unintended uses
8787
end
8888
end
@@ -250,7 +250,6 @@ end
250250
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
251251
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
252252

253-
const DEFAULT_MAX_DEGREE = 2
254253
max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE
255254
max_degree(::Type{<:AbstractNode{D}}) where {D} = D
256255
max_degree(node::AbstractNode) = max_degree(typeof(node))

test/test_zygote_gradient_wrapper.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,8 @@
1111
g(x, y) = x * y
1212
@test repr(_zygote_gradient(g, Val(2))) == "∂g"
1313

14-
# Test binary gradient (first partial)
15-
@test repr(_zygote_gradient(g, Val(2), Val(1))) == "∂₁g"
16-
17-
# Test binary gradient (second partial)
18-
@test repr(_zygote_gradient(g, Val(2), Val(2))) == "∂₂g"
19-
2014
# Test with standard operators
2115
@test repr(_zygote_gradient(+, Val(2))) == "∂+"
22-
@test repr(_zygote_gradient(*, Val(2), Val(1))) == "∂₁*"
23-
@test repr(_zygote_gradient(*, Val(2), Val(2))) == "∂₂*"
24-
25-
first_partial = _zygote_gradient(log, Val(2), Val(1))
26-
nested = _zygote_gradient(first_partial, Val(1))
27-
@test repr(nested) == "∂∂₁log"
28-
29-
# Also should work with text/plain
30-
@test repr("text/plain", nested) == "∂∂₁log"
3116
end
3217

3318
@testitem "ZygoteGradient evaluation" begin
@@ -45,10 +30,4 @@ end
4530
# Test binary gradient (both partials)
4631
g(x, y) = x * y
4732
@test (_zygote_gradient(g, Val(2)))(x, y) == (3.0, 2.0)
48-
49-
# Test binary gradient (first partial)
50-
@test (_zygote_gradient(g, Val(2), Val(1)))(x, y) == 3.0
51-
52-
# Test second partial
53-
@test (_zygote_gradient(g, Val(2), Val(2)))(x, y) == 2.0
5433
end

0 commit comments

Comments
 (0)