@@ -335,27 +335,73 @@ Reverse-mode evaluation of an expression tree given in `f`.
335335 * This function assumes that `f.reverse_storage` has been initialized with 0.0.
336336"""
337337function _reverse_eval (f:: _SubexpressionStorage )
338- @assert length (f. reverse_storage) >= length (f. nodes )
339- @assert length (f. partials_storage) >= length (f. nodes )
338+ @assert length (f. reverse_storage) >= _length (f. sizes )
339+ @assert length (f. partials_storage) >= _length (f. sizes )
340340 # f.nodes is already in order such that parents always appear before
341341 # children so a forward pass through nodes is a backwards pass through the
342342 # tree.
343- f. reverse_storage[1 ] = one (Float64)
344- for k in 2 : length (f. nodes)
343+ children_arr = SparseArrays. rowvals (f. adj)
344+ for i in _storage_range (f. sizes, 1 )
345+ f. reverse_storage[i] = one (Float64)
346+ end
347+ for k in 1 : length (f. nodes)
348+ @show f. reverse_storage
345349 node = f. nodes[k]
346- if node. type == Nonlinear. NODE_VALUE ||
347- node. type == Nonlinear. NODE_LOGIC ||
348- node. type == Nonlinear. NODE_COMPARISON ||
349- node. type == Nonlinear. NODE_PARAMETER
350+ children_indices = SparseArrays. nzrange (f. adj, k)
351+ if node. type == MOI. Nonlinear. NODE_CALL_MULTIVARIATE
352+ if node. index in
353+ eachindex (MOI. Nonlinear. DEFAULT_MULTIVARIATE_OPERATORS)
354+ op = MOI. Nonlinear. DEFAULT_MULTIVARIATE_OPERATORS[node. index]
355+ if op == :vect
356+ @assert _eachindex (f. sizes, k) ==
357+ eachindex (children_indices)
358+ for j in eachindex (children_indices)
359+ ix = children_arr[children_indices[j]]
360+ rev_parent_j = @j f. reverse_storage[k]
361+ # partial is 1 so we can ignore it
362+ @s f. reverse_storage[ix] = rev_parent_j
363+ end
364+ continue
365+ elseif op == :dot
366+ # Node `k` is scalar, the jacobian w.r.t. each vectorized input
367+ # child is a row vector whose entries are stored in `f.partials_storage`
368+ rev_parent = @s f. reverse_storage[k]
369+ for j in
370+ _eachindex (f. sizes, children_arr[children_indices[1 ]])
371+ for child_idx in children_indices
372+ ix = children_arr[child_idx]
373+ partial = @j f. partials_storage[ix]
374+ val = ifelse (
375+ rev_parent == 0.0 && ! isfinite (partial),
376+ rev_parent,
377+ rev_parent * partial,
378+ )
379+ @j f. reverse_storage[ix] = val
380+ end
381+ end
382+ continue
383+ end
384+ end
385+ elseif node. type != MOI. Nonlinear. NODE_CALL_UNIVARIATE
350386 continue
351387 end
352- rev_parent = f. reverse_storage[node. parent]
353- partial = f. partials_storage[k]
354- f. reverse_storage[k] = ifelse (
355- rev_parent == 0.0 && ! isfinite (partial),
356- rev_parent,
357- rev_parent * partial,
358- )
388+ # Node `k` has same size as its children.
389+ # The Jacobian (between the vectorized versions) is diagonal and the diagonal entries
390+ # are stored in `f.partials_storage`
391+ for j in _eachindex (f. sizes, k)
392+ rev_parent = @j f. reverse_storage[k]
393+ for child_idx in children_indices
394+ ix = children_arr[child_idx]
395+ @assert _size (f. sizes, k) == _size (f. sizes, ix)
396+ partial = @j f. partials_storage[ix]
397+ val = ifelse (
398+ rev_parent == 0.0 && ! isfinite (partial),
399+ rev_parent,
400+ rev_parent * partial,
401+ )
402+ @j f. reverse_storage[ix] = val
403+ end
404+ end
359405 end
360406 return
361407end
@@ -406,12 +452,12 @@ function _extract_reverse_pass_inner(
406452 subexpressions:: AbstractVector{T} ,
407453 scale:: T ,
408454) where {T}
409- @assert length (f. reverse_storage) >= length (f. nodes )
455+ @assert length (f. reverse_storage) >= _length (f. sizes )
410456 for (k, node) in enumerate (f. nodes)
411457 if node. type == Nonlinear. NODE_VARIABLE
412- output[node. index] += scale * f. reverse_storage[k]
458+ output[node. index] += scale * @s f. reverse_storage[k]
413459 elseif node. type == Nonlinear. NODE_SUBEXPRESSION
414- subexpressions[node. index] += scale * f. reverse_storage[k]
460+ subexpressions[node. index] += scale * @s f. reverse_storage[k]
415461 end
416462 end
417463 return
0 commit comments