@@ -254,55 +254,97 @@ function eval_grad_tree_array(
 end

 @generated function _eval_grad_tree_array(
-    tree::AbstractExpressionNode{T},
+    tree::AbstractExpressionNode{T,D},
     n_gradients,
-    index_tree::Union{NodeIndex,Nothing},
+    index_tree::Union{NodeIndex{<:Any,D},Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
     ::Val{mode},
-)::ResultOk2 where {T<:Number,mode}
-    nuna = get_nuna(operators)
-    nbin = get_nbin(operators)
-    deg1_branch_skeleton = quote
-        grad_deg1_eval(
-            tree, n_gradients, index_tree, cX, operators.unaops[i], operators, Val(mode)
+)::ResultOk2 where {T<:Number,D,mode}
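+    # Dispatch on the node's runtime degree. For D = 2, `@nif` unrolls to
+    # roughly `deg == 1 ? dispatch_grad_degn_eval(..., Val(1), ...) :
+    # dispatch_grad_degn_eval(..., Val(2), ...)`, so each branch sees the
+    # child count as a compile-time constant.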
+    quote
+        deg = tree.degree
+        deg == 0 && return grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
+        Base.Cartesian.@nif(
+            $D,
+            i -> i == deg,
+            i -> dispatch_grad_degn_eval(
+                tree, n_gradients, index_tree, cX, Val(i), operators, Val(mode)
+            )
         )
     end
-    deg2_branch_skeleton = quote
-        grad_deg2_eval(
-            tree, n_gradients, index_tree, cX, operators.binops[i], operators, Val(mode)
+end
+
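+# Evaluate the gradient at a node of known degree `degree`. The operator lookup
+# below is either unrolled with `@nif` or done by dynamic indexing, depending
+# on how many operators of this arity exist.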
+@generated function dispatch_grad_degn_eval(
+    tree::AbstractExpressionNode{T},
+    n_gradients,
+    index_tree::Union{NodeIndex,Nothing},
+    cX::AbstractMatrix{T},
+    ::Val{degree},
+    operators::OperatorEnum{OPS},
+    ::Val{mode},
+) where {T<:Number,degree,OPS,mode}
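+    # `setup` recursively evaluates each of the `degree` children (returning
+    # early if any child fails) and gathers their values and Jacobians into
+    # tuples; it is spliced into both branches constructed below.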
+    setup = quote
+        cs = get_children(tree, Val($degree))
+        index_cs =
+            isnothing(index_tree) ? index_tree : get_children(index_tree, Val($degree))
+        Base.Cartesian.@nexprs(
+            $degree,
+            i -> begin
+                result_i = eval_grad_tree_array(
+                    cs[i],
+                    n_gradients,
+                    isnothing(index_cs) ? index_cs : index_cs[i],
+                    cX,
+                    operators,
+                    Val(mode),
+                )
+                !result_i.ok && return result_i
+            end
         )
+        x_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x)
+        d_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx)
+        op_idx = tree.op
     end
-    deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
-        quote
-            i = tree.op
-            $deg1_branch_skeleton
-        end
-    else
-        quote
-            op_idx = tree.op
-            Base.Cartesian.@nif($nuna, i -> i == op_idx, i -> $deg1_branch_skeleton)
-        end
-    end
-    deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
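+    # `OPS` is the tuple-of-operator-tuples type parameter, so this counts the
+    # operators of arity `degree`. When there are many, skip the `@nif`
+    # unrolling and index `operators[$degree][op_idx]` dynamically instead.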
+    nops = length(OPS.types[degree].types)
+    if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN
         quote
-            i = tree.op
-            $deg2_branch_skeleton
+            $setup
+            grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][op_idx])
         end
     else
         quote
-            op_idx = tree.op
-            Base.Cartesian.@nif($nbin, i -> i == op_idx, i -> $deg2_branch_skeleton)
+            $setup
+            Base.Cartesian.@nif(
+                $nops,
+                i -> i == op_idx,
+                i -> grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][i])
+            )
         end
     end
+end
+
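+# Apply an `N`-ary operator and the chain rule to the child results, reusing
+# the first child's buffers in place for the output value and gradient.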
+@generated function grad_degn_eval(
+    x_cumulators::NTuple{N}, d_cumulators::NTuple{N}, op::F
+) where {N,F}
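+    # `@nexprs` binds x_cumulator_1..x_cumulator_N (and likewise the
+    # d_cumulator_i) as locals, so the loop avoids repeated tuple indexing.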
     quote
-        if tree.degree == 0
-            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
-        elseif tree.degree == 1
-            $deg1_branch
-        else
-            $deg2_branch
+        Base.Cartesian.@nexprs($N, i -> begin
+            x_cumulator_i = x_cumulators[i]
+            d_cumulator_i = d_cumulators[i]
+        end)
+        diff_op = _zygote_gradient(op, Val($N))
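+        # `diff_op` evaluates all N partial derivatives of `op` at a point;
+        # the `@ntuple` assignment below destructures them into grad_1..grad_N.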
+        @inbounds @simd for j in eachindex(x_cumulator_1)
+            x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j])
+            Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall(
+                $N, diff_op, i -> x_cumulator_i[j]
+            )
+            x_cumulator_1[j] = x
+            for k in axes(d_cumulator_1, 1)
+                d_cumulator_1[k, j] = Base.Cartesian.@ncall(
+                    $N, +, i -> grad_i * d_cumulator_i[k, j]
+                )
+            end
         end
+        return ResultOk2(x_cumulator_1, d_cumulator_1, true)
     end
 end

@@ -344,85 +386,4 @@ function grad_deg0_eval(
     return ResultOk2(const_part, derivative_part, true)
 end

-function grad_deg1_eval(
-    tree::AbstractExpressionNode{T},
-    n_gradients,
-    index_tree::Union{NodeIndex,Nothing},
-    cX::AbstractMatrix{T},
-    op::F,
-    operators::OperatorEnum,
-    ::Val{mode},
-)::ResultOk2 where {T<:Number,F,mode}
-    result = eval_grad_tree_array(
-        tree.l,
-        n_gradients,
-        index_tree === nothing ? index_tree : index_tree.l,
-        cX,
-        operators,
-        Val(mode),
-    )
-    !result.ok && return result
-
-    cumulator = result.x
-    dcumulator = result.dx
-    diff_op = _zygote_gradient(op, Val(1))
-    @inbounds @simd for j in axes(dcumulator, 2)
-        x = op(cumulator[j])::T
-        dx = diff_op(cumulator[j])::T
-
-        cumulator[j] = x
-        for k in axes(dcumulator, 1)
-            dcumulator[k, j] = dx * dcumulator[k, j]
-        end
-    end
-    return result
-end
-
-function grad_deg2_eval(
-    tree::AbstractExpressionNode{T},
-    n_gradients,
-    index_tree::Union{NodeIndex,Nothing},
-    cX::AbstractMatrix{T},
-    op::F,
-    operators::OperatorEnum,
-    ::Val{mode},
-)::ResultOk2 where {T<:Number,F,mode}
-    result_l = eval_grad_tree_array(
-        tree.l,
-        n_gradients,
-        index_tree === nothing ? index_tree : index_tree.l,
-        cX,
-        operators,
-        Val(mode),
-    )
-    !result_l.ok && return result_l
-    result_r = eval_grad_tree_array(
-        tree.r,
-        n_gradients,
-        index_tree === nothing ? index_tree : index_tree.r,
-        cX,
-        operators,
-        Val(mode),
-    )
-    !result_r.ok && return result_r
-
-    cumulator_l = result_l.x
-    dcumulator_l = result_l.dx
-    cumulator_r = result_r.x
-    dcumulator_r = result_r.dx
-    diff_op = _zygote_gradient(op, Val(2))
-    @inbounds @simd for j in axes(dcumulator_l, 2)
-        c1 = cumulator_l[j]
-        c2 = cumulator_r[j]
-        x = op(c1, c2)::T
-        dx1, dx2 = diff_op(c1, c2)::Tuple{T,T}
-        cumulator_l[j] = x
-        for k in axes(dcumulator_l, 1)
-            dcumulator_l[k, j] = dx1 * dcumulator_l[k, j] + dx2 * dcumulator_r[k, j]
-        end
-    end
-
-    return result_l
-end
-
 end