@@ -31,15 +31,15 @@ function _fd_forward(fdm, f, rettype, y, activities)
3131 if rettype <: Union{Duplicated,DuplicatedNoNeed}
3232 all (ignores) && return zero_tangent (y)
3333 sig_arg_dval_vec, _ = to_vec (ẋs[.! ignores])
34- ret_deval_vec = FiniteDifferences. jvp (fdm, f_vec,
35- ( sig_arg_val_vec, sig_arg_dval_vec) )
34+ ret_deval_vec = FiniteDifferences. _jvp (fdm, f_vec,
35+ sig_arg_val_vec, sig_arg_dval_vec)
3636 return from_vec_out (ret_deval_vec)
3737 elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed}
3838 all (ignores) && return (var"1" = zero_tangent (y),)
3939 ret_dvals = map (ẋs[.! ignores]. .. ) do sig_args_dvals...
4040 sig_args_dvals_vec, _ = to_vec (sig_args_dvals)
41- ret_dval_vec = FiniteDifferences. jvp (fdm, f_vec,
42- ( sig_arg_val_vec, sig_args_dvals_vec) )
41+ ret_dval_vec = FiniteDifferences. _jvp (fdm, f_vec,
42+ sig_arg_val_vec, sig_args_dvals_vec)
4343 return from_vec_out (ret_dval_vec)
4444 end
4545 return NamedTuple {ntuple(Symbol, length(ret_dvals))} (ret_dvals)
@@ -59,6 +59,18 @@ function multi_tovec(active_return, vals)
5959 end
6060end
6161
62+ function j′vp (fdm, f_vec, ȳ, x)
63+ mat = transpose (first (FiniteDifferences. jacobian (fdm, f_vec, x)))
64+ result = zero (x)
65+ for i in 1 : length (ȳ)
66+ tp = @inbounds ȳ[i]
67+ if isfinite (tp) && ! iszero (tp)
68+ result .+ = mat[:, i] .* tp
69+ end
70+ end
71+ return result
72+ end
73+
6274#=
6375 _fd_reverse(fdm, f, ȳ, activities, active_return)
6476
@@ -98,13 +110,12 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return)
98110 if ! is_batch
99111 ȳ_extended = (ȳ, s̄igargs... )
100112 ȳ_extended_vec = multi_tovec (active_return, ȳ_extended)
101- fd_vec = only (FiniteDifferences . j′vp (fdm, f_vec, ȳ_extended_vec, sigargs_vec) )
113+ fd_vec = j′vp (fdm, f_vec, ȳ_extended_vec, sigargs_vec)
102114 fd = from_vec_in (fd_vec)
103115 else
104116 fd = Tuple (zip (map (ȳ, s̄igargs... ) do ȳ_extended...
105117 ȳ_extended_vec = multi_tovec (active_return, ȳ_extended)
106- fd_vec = only (FiniteDifferences. j′vp (fdm, f_vec, ȳ_extended_vec,
107- sigargs_vec))
118+ fd_vec = j′vp (fdm, f_vec, ȳ_extended_vec, sigargs_vec)
108119 return from_vec_in (fd_vec)
109120 end ... ))
110121 end
0 commit comments