@@ -247,6 +247,83 @@ function _forward_eval(
247247 tmp_dot += v1 * v2
248248 end
249249 @s f. forward_storage[k] = tmp_dot
250+ elseif node. index == 12 # hcat
251+ idx1, idx2 = children_indices
252+ ix1 = children_arr[idx1]
253+ ix2 = children_arr[idx2]
254+ nb_cols1 = f. sizes. ndims[ix1] <= 1 ? 1 : _size (f. sizes, ix1, 2 )
255+ col_size = f. sizes. ndims[ix1] == 0 ? 1 : _size (f. sizes, k, 1 )
256+ for j in _eachindex (f. sizes, ix1)
257+ @j f. partials_storage[ix1] = one (T)
258+ val = @j f. forward_storage[ix1]
259+ @j f. forward_storage[k] = val
260+ end
261+ for j in _eachindex (f. sizes, ix2)
262+ @j f. partials_storage[ix2] = one (T)
263+ val = @j f. forward_storage[ix2]
264+ _setindex! (
265+ f. forward_storage,
266+ val,
267+ f. sizes,
268+ k,
269+ j + nb_cols1 * col_size,
270+ )
271+ end
272+ elseif node. index == 13 # vcat
273+ idx1, idx2 = children_indices
274+ ix1 = children_arr[idx1]
275+ ix2 = children_arr[idx2]
276+ nb_rows1 = f. sizes. ndims[ix1] <= 1 ? 1 : _size (f. sizes, ix1, 1 )
277+ nb_rows2 = f. sizes. ndims[ix2] <= 1 ? 1 : _size (f. sizes, ix2, 1 )
278+ nb_rows = nb_rows1 + nb_rows2
279+ for j in _eachindex (f. sizes, ix1)
280+ @j f. partials_storage[ix1] = one (T)
281+ val = @j f. forward_storage[ix1]
282+ _setindex! (
283+ f. forward_storage,
284+ val,
285+ f. sizes,
286+ k,
287+ div (j- 1 , nb_rows1) * nb_rows + 1 + (j- 1 ) % nb_rows1,
288+ )
289+ end
290+ for j in _eachindex (f. sizes, ix2)
291+ @j f. partials_storage[ix2] = one (T)
292+ val = @j f. forward_storage[ix2]
293+ _setindex! (
294+ f. forward_storage,
295+ val,
296+ f. sizes,
297+ k,
298+ div (j- 1 , nb_rows1) * nb_rows +
299+ 1 +
300+ (j- 1 ) % nb_rows1 +
301+ nb_rows1,
302+ )
303+ end
304+ elseif node. index == 14 # norm
305+ ix = children_arr[children_indices[1 ]]
306+ tmp_norm_squared = zero (T)
307+ for j in _eachindex (f. sizes, ix)
308+ v = @j f. forward_storage[ix]
309+ tmp_norm_squared += v * v
310+ end
311+ @s f. forward_storage[k] = sqrt (tmp_norm_squared)
312+ for j in _eachindex (f. sizes, ix)
313+ v = @j f. forward_storage[ix]
314+ if tmp_norm_squared == 0
315+ @j f. partials_storage[ix] = zero (T)
316+ else
317+ @j f. partials_storage[ix] = v / @s f. forward_storage[k]
318+ end
319+ end
320+ elseif node. index == 16 # row
321+ for j in _eachindex (f. sizes, k)
322+ ix = children_arr[children_indices[j]]
323+ @s f. partials_storage[ix] = one (T)
324+ val = @s f. forward_storage[ix]
325+ @j f. forward_storage[k] = val
326+ end
250327 else # atan, min, max
251328 f_input = _UnsafeVectorView (d. jac_storage, N)
252329 ∇f = _UnsafeVectorView (d. user_output_buffer, N)
@@ -380,6 +457,149 @@ function _reverse_eval(f::_SubexpressionStorage)
380457 end
381458 end
382459 continue
460+ elseif op == :hcat
461+ idx1, idx2 = children_indices
462+ ix1 = children_arr[idx1]
463+ ix2 = children_arr[idx2]
464+ nb_cols1 =
465+ f. sizes. ndims[ix1] <= 1 ? 1 : _size (f. sizes, ix1, 2 )
466+ col_size =
467+ f. sizes. ndims[ix1] == 0 ? 1 : _size (f. sizes, k, 1 )
468+ for j in _eachindex (f. sizes, ix1)
469+ partial = @j f. partials_storage[ix1]
470+ val = ifelse (
471+ _getindex (f. reverse_storage, f. sizes, k, j) ==
472+ 0.0 && ! isfinite (partial),
473+ _getindex (f. reverse_storage, f. sizes, k, j),
474+ _getindex (f. reverse_storage, f. sizes, k, j) *
475+ partial,
476+ )
477+ @j f. reverse_storage[ix1] = val
478+ end
479+ for j in _eachindex (f. sizes, ix2)
480+ partial = @j f. partials_storage[ix2]
481+ val = ifelse (
482+ _getindex (
483+ f. reverse_storage,
484+ f. sizes,
485+ k,
486+ j + nb_cols1 * col_size,
487+ ) == 0.0 && ! isfinite (partial),
488+ _getindex (
489+ f. reverse_storage,
490+ f. sizes,
491+ k,
492+ j + nb_cols1 * col_size,
493+ ),
494+ _getindex (
495+ f. reverse_storage,
496+ f. sizes,
497+ k,
498+ j + nb_cols1 * col_size,
499+ ) * partial,
500+ )
501+ @j f. reverse_storage[ix2] = val
502+ end
503+ continue
504+ elseif op == :vcat
505+ idx1, idx2 = children_indices
506+ ix1 = children_arr[idx1]
507+ ix2 = children_arr[idx2]
508+ nb_rows1 =
509+ f. sizes. ndims[ix1] <= 1 ? 1 : _size (f. sizes, ix1, 1 )
510+ nb_rows2 =
511+ f. sizes. ndims[ix2] <= 1 ? 1 : _size (f. sizes, ix2, 1 )
512+ nb_rows = nb_rows1 + nb_rows2
513+ row_size =
514+ f. sizes. ndims[ix1] == 0 ? 1 : _size (f. sizes, k, 2 )
515+ for j in _eachindex (f. sizes, ix1)
516+ partial = @j f. partials_storage[ix1]
517+ val = ifelse (
518+ _getindex (
519+ f. reverse_storage,
520+ f. sizes,
521+ k,
522+ div (j- 1 , nb_rows1) * nb_rows +
523+ 1 +
524+ (j- 1 ) % nb_rows1,
525+ ) == 0.0 && ! isfinite (partial),
526+ _getindex (
527+ f. reverse_storage,
528+ f. sizes,
529+ k,
530+ div (j- 1 , nb_rows1) * nb_rows +
531+ 1 +
532+ (j- 1 ) % nb_rows1,
533+ ),
534+ _getindex (
535+ f. reverse_storage,
536+ f. sizes,
537+ k,
538+ div (j- 1 , nb_rows1) * nb_rows +
539+ 1 +
540+ (j- 1 ) % nb_rows1,
541+ ) * partial,
542+ )
543+ @j f. reverse_storage[ix1] = val
544+ end
545+ for j in _eachindex (f. sizes, ix2)
546+ partial = @j f. partials_storage[ix2]
547+ val = ifelse (
548+ _getindex (
549+ f. reverse_storage,
550+ f. sizes,
551+ k,
552+ div (j- 1 , nb_rows1) * nb_rows +
553+ 1 +
554+ (j- 1 ) % nb_rows1 +
555+ nb_rows1,
556+ ) == 0.0 && ! isfinite (partial),
557+ _getindex (
558+ f. reverse_storage,
559+ f. sizes,
560+ k,
561+ div (j- 1 , nb_rows1) * nb_rows +
562+ 1 +
563+ (j- 1 ) % nb_rows1 +
564+ nb_rows1,
565+ ),
566+ _getindex (
567+ f. reverse_storage,
568+ f. sizes,
569+ k,
570+ div (j- 1 , nb_rows1) * nb_rows +
571+ 1 +
572+ (j- 1 ) % nb_rows1 +
573+ nb_rows1,
574+ ) * partial,
575+ )
576+ @j f. reverse_storage[ix2] = val
577+ end
578+ continue
579+ elseif op == :norm
580+ # Node `k` is scalar, the jacobian w.r.t. the vectorized input
581+ # child is a row vector whose entries are stored in `f.partials_storage`
582+ rev_parent = @s f. reverse_storage[k]
583+ for j in
584+ _eachindex (f. sizes, children_arr[children_indices[1 ]])
585+ ix = children_arr[children_indices[1 ]]
586+ partial = @j f. partials_storage[ix]
587+ val = ifelse (
588+ rev_parent == 0.0 && ! isfinite (partial),
589+ rev_parent,
590+ rev_parent * partial,
591+ )
592+ @j f. reverse_storage[ix] = val
593+ end
594+ continue
595+ elseif op == :row
596+ for j in _eachindex (f. sizes, k)
597+ ix = children_arr[children_indices[j]]
598+ rev_parent_j = @j f. reverse_storage[k]
599+ # partial is 1 so we can ignore it
600+ @s f. reverse_storage[ix] = rev_parent_j
601+ end
602+ continue
383603 end
384604 end
385605 elseif node. type != MOI. Nonlinear. NODE_CALL_UNIVARIATE
0 commit comments