
Commit 5f977ab

feat: make grad compatible with n-arity
1 parent 3845539 commit 5f977ab

3 files changed: +77 −119 lines changed

ext/DynamicExpressionsZygoteExt.jl

Lines changed: 1 addition & 4 deletions
@@ -7,10 +7,7 @@ function _zygote_gradient(op::F, ::Val{degree}) where {F,degree}
     return ZygoteGradient{F,degree}(op)
 end
 
-function (g::ZygoteGradient{F,1})(x) where {F}
-    out = only(gradient(g.op, x))
-    return out === nothing ? zero(x) : out
-end
+# All this does is remove `nothing`, so that we get type stability
 function (g::ZygoteGradient{F,degree})(args::Vararg{Any,degree}) where {F,degree}
     partials = gradient(g.op, args...)
     return ntuple(i -> @something(partials[i], zero(args[i])), Val(degree))
src/EvaluateDerivative.jl

Lines changed: 75 additions & 114 deletions
@@ -254,55 +254,97 @@ function eval_grad_tree_array(
 end
 
 @generated function _eval_grad_tree_array(
-    tree::AbstractExpressionNode{T},
+    tree::AbstractExpressionNode{T,D},
     n_gradients,
-    index_tree::Union{NodeIndex,Nothing},
+    index_tree::Union{NodeIndex{<:Any,D},Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
     ::Val{mode},
-)::ResultOk2 where {T<:Number,mode}
-    nuna = get_nuna(operators)
-    nbin = get_nbin(operators)
-    deg1_branch_skeleton = quote
-        grad_deg1_eval(
-            tree, n_gradients, index_tree, cX, operators.unaops[i], operators, Val(mode)
+)::ResultOk2 where {T<:Number,D,mode}
+    quote
+        deg = tree.degree
+        deg == 0 && return grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
+        Base.Cartesian.@nif(
+            $D,
+            i -> i == deg,
+            i -> dispatch_grad_degn_eval(
+                tree, n_gradients, index_tree, cX, Val(i), operators, Val(mode)
+            )
         )
     end
-    deg2_branch_skeleton = quote
-        grad_deg2_eval(
-            tree, n_gradients, index_tree, cX, operators.binops[i], operators, Val(mode)
+end
+
+@generated function dispatch_grad_degn_eval(
+    tree::AbstractExpressionNode{T},
+    n_gradients,
+    index_tree::Union{NodeIndex,Nothing},
+    cX::AbstractMatrix{T},
+    ::Val{degree},
+    operators::OperatorEnum{OPS},
+    ::Val{mode},
+) where {T<:Number,degree,OPS,mode}
+    setup = quote
+        cs = get_children(tree, Val($degree))
+        index_cs =
+            isnothing(index_tree) ? index_tree : get_children(index_tree, Val($degree))
+        Base.Cartesian.@nexprs(
+            $degree,
+            i -> begin
+                result_i = eval_grad_tree_array(
+                    cs[i],
+                    n_gradients,
+                    isnothing(index_cs) ? index_cs : index_cs[i],
+                    cX,
+                    operators,
+                    Val(mode),
+                )
+                !result_i.ok && return result_i
+            end
         )
+        x_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x)
+        d_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx)
+        op_idx = tree.op
     end
-    deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
-        quote
-            i = tree.op
-            $deg1_branch_skeleton
-        end
-    else
-        quote
-            op_idx = tree.op
-            Base.Cartesian.@nif($nuna, i -> i == op_idx, i -> $deg1_branch_skeleton)
-        end
-    end
-    deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
+    nops = length(OPS.types[degree].types)
+    if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN
         quote
-            i = tree.op
-            $deg2_branch_skeleton
+            $setup
+            grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][op_idx])
         end
     else
         quote
-            op_idx = tree.op
-            Base.Cartesian.@nif($nbin, i -> i == op_idx, i -> $deg2_branch_skeleton)
+            $setup
+            Base.Cartesian.@nif(
+                $nops,
+                i -> i == op_idx,
+                i -> grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][i])
+            )
         end
     end
+end
+
+@generated function grad_degn_eval(
+    x_cumulators::NTuple{N}, d_cumulators::NTuple{N}, op::F
+) where {N,F}
     quote
-        if tree.degree == 0
-            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
-        elseif tree.degree == 1
-            $deg1_branch
-        else
-            $deg2_branch
+        Base.Cartesian.@nexprs($N, i -> begin
+            x_cumulator_i = x_cumulators[i]
+            d_cumulator_i = d_cumulators[i]
+        end)
+        diff_op = _zygote_gradient(op, Val($N))
+        @inbounds @simd for j in eachindex(x_cumulator_1)
+            x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j])
+            Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall(
+                $N, diff_op, i -> x_cumulator_i[j]
+            )
+            x_cumulator_1[j] = x
+            for k in axes(d_cumulator_1, 1)
+                d_cumulator_1[k, j] = Base.Cartesian.@ncall(
+                    $N, +, i -> grad_i * d_cumulator_i[k, j]
+                )
+            end
         end
+        return ResultOk2(x_cumulator_1, d_cumulator_1, true)
     end
 end
 
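
Aside: the degree dispatch above relies on `Base.Cartesian.@nif`, which unrolls a runtime check against a statically known bound into an if/elseif chain. A toy sketch of the pattern (illustrative only, not library code), assuming a maximum degree of 3:

using Base.Cartesian: @nif

handle(::Val{N}) where {N} = "handling a degree-$N node"

# Expands at compile time into an if/elseif chain over i = 1..3,
# with the last branch acting as the fallback.
pick_branch(deg::Int) = @nif(3, i -> i == deg, i -> handle(Val(i)))

pick_branch(1)  # "handling a degree-1 node"
pick_branch(3)  # "handling a degree-3 node"

Because each branch passes a literal `Val(i)`, downstream code can specialize on the degree even though `deg` is only known at runtime.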

@@ -344,85 +386,4 @@ function grad_deg0_eval(
     return ResultOk2(const_part, derivative_part, true)
 end
 
-function grad_deg1_eval(
-    tree::AbstractExpressionNode{T},
-    n_gradients,
-    index_tree::Union{NodeIndex,Nothing},
-    cX::AbstractMatrix{T},
-    op::F,
-    operators::OperatorEnum,
-    ::Val{mode},
-)::ResultOk2 where {T<:Number,F,mode}
-    result = eval_grad_tree_array(
-        tree.l,
-        n_gradients,
-        index_tree === nothing ? index_tree : index_tree.l,
-        cX,
-        operators,
-        Val(mode),
-    )
-    !result.ok && return result
-
-    cumulator = result.x
-    dcumulator = result.dx
-    diff_op = _zygote_gradient(op, Val(1))
-    @inbounds @simd for j in axes(dcumulator, 2)
-        x = op(cumulator[j])::T
-        dx = diff_op(cumulator[j])::T
-
-        cumulator[j] = x
-        for k in axes(dcumulator, 1)
-            dcumulator[k, j] = dx * dcumulator[k, j]
-        end
-    end
-    return result
-end
-
-function grad_deg2_eval(
-    tree::AbstractExpressionNode{T},
-    n_gradients,
-    index_tree::Union{NodeIndex,Nothing},
-    cX::AbstractMatrix{T},
-    op::F,
-    operators::OperatorEnum,
-    ::Val{mode},
-)::ResultOk2 where {T<:Number,F,mode}
-    result_l = eval_grad_tree_array(
-        tree.l,
-        n_gradients,
-        index_tree === nothing ? index_tree : index_tree.l,
-        cX,
-        operators,
-        Val(mode),
-    )
-    !result_l.ok && return result_l
-    result_r = eval_grad_tree_array(
-        tree.r,
-        n_gradients,
-        index_tree === nothing ? index_tree : index_tree.r,
-        cX,
-        operators,
-        Val(mode),
-    )
-    !result_r.ok && return result_r
-
-    cumulator_l = result_l.x
-    dcumulator_l = result_l.dx
-    cumulator_r = result_r.x
-    dcumulator_r = result_r.dx
-    diff_op = _zygote_gradient(op, Val(2))
-    @inbounds @simd for j in axes(dcumulator_l, 2)
-        c1 = cumulator_l[j]
-        c2 = cumulator_r[j]
-        x = op(c1, c2)::T
-        dx1, dx2 = diff_op(c1, c2)::Tuple{T,T}
-        cumulator_l[j] = x
-        for k in axes(dcumulator_l, 1)
-            dcumulator_l[k, j] = dx1 * dcumulator_l[k, j] + dx2 * dcumulator_r[k, j]
-        end
-    end
-
-    return result_l
-end
-
 end
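
Aside: the `grad_degn_eval` loop above is the multivariate chain rule, accumulated in-place into the first child's buffers. A non-macro sketch of the same update for a fixed arity of 2 (names and signature are hypothetical, for illustration only):

# x1, x2 :: Vector — child evaluations per data point
# d1, d2 :: Matrix — child gradients, sized (n_gradients, n_data)
# op     :: binary operator; diff_op returns its two partials
function accumulate_binary!(x1, d1, x2, d2, op, diff_op)
    @inbounds for j in eachindex(x1)
        g1, g2 = diff_op(x1[j], x2[j])   # partials at the old child values
        x1[j] = op(x1[j], x2[j])         # overwrite with the node's own value
        for k in axes(d1, 1)
            d1[k, j] = g1 * d1[k, j] + g2 * d2[k, j]  # chain rule per gradient row
        end
    end
    return x1, d1
end

The `Base.Cartesian.@ncall`/`@ntuple` calls in the generated code unroll this same shape for an arbitrary arity `N` at compile time.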

test/test_zygote_gradient_wrapper.jl

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ end
 
 # Test unary gradient
 f(x) = x^2
-@test (_zygote_gradient(f, Val(1)))(x) == 4.0
+@test (_zygote_gradient(f, Val(1)))(x) == (4.0,)
 
 # Test binary gradient (both partials)
 g(x, y) = x * y
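
Because the wrapper is now arity-generic, the same tuple convention extends past unary and binary operators. A hedged sketch with a made-up ternary operator (not part of the test file):

# Hypothetical ternary operator: fused multiply-add.
fma3(x, y, z) = x * y + z

# Expected partials are (y, x, 1); e.g. at (2.0, 3.0, 1.0):
# @test (_zygote_gradient(fma3, Val(3)))(2.0, 3.0, 1.0) == (3.0, 2.0, 1.0)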
