@@ -209,10 +209,14 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
209209end
210210
211211# jvp fallback scalar
212- __jacvec (args... ; kwargs... ) = JacVec (args... ; kwargs... )
213- function __jacvec (uf, u:: Number ; autodiff, kwargs... )
214- @assert autodiff isa AutoForwardDiff " Only ForwardDiff is currently supported."
215- return JVPScalar (uf, u, autodiff)
212+ function __jacvec (uf, u; autodiff, kwargs... )
213+ if ! (autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff)
214+ _ad = autodiff
215+ autodiff = ifelse (ForwardDiff. can_dual (eltype (u)), AutoForwardDiff (),
216+ AutoFiniteDiff ())
217+ @warn " $(_ad) not supported for JacVec. Using $(autodiff) instead."
218+ end
219+ return u isa Number ? JVPScalar (uf, u, autodiff) : JacVec (uf, u; autodiff, kwargs... )
216220end
217221
218222@concrete mutable struct JVPScalar
@@ -221,10 +225,17 @@ end
221225 autodiff
222226end
223227
224- function Base.:* (jvp:: JVPScalar , v)
225- T = typeof (ForwardDiff. Tag (typeof (jvp. uf), typeof (jvp. u)))
226- out = jvp. uf (ForwardDiff. Dual {T} (jvp. u, v))
227- return ForwardDiff. extract_derivative (T, out)
228+ function Base.:* (jvp:: JVPScalar , v:: Number )
229+ if jvp. autodiff isa AutoForwardDiff
230+ T = typeof (ForwardDiff. Tag (typeof (jvp. uf), typeof (jvp. u)))
231+ out = jvp. uf (ForwardDiff. Dual {T} (jvp. u, v))
232+ return ForwardDiff. extract_derivative (T, out)
233+ elseif jvp. autodiff isa AutoFiniteDiff
234+ J = FiniteDiff. finite_difference_derivative (jvp. uf, jvp. u, jvp. autodiff. fdtype)
235+ return J * v
236+ else
237+ error (" Only ForwardDiff & FiniteDiff is currently supported." )
238+ end
228239end
229240
230241# Generic Handling of Krylov Methods for Normal Form Linear Solves
0 commit comments