@@ -2,11 +2,11 @@ module DiffEqBaseForwardDiffExt
22
33using DiffEqBase, ForwardDiff
44using DiffEqBase. ArrayInterface
5- using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
6- import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0, prob2dtmin,
5+ using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution,
6+ RecursiveArrayTools, reduce_tup
7+ import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
78 promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
8- InternalITP,
9- nextfloat_tdir, promote_dual
9+ InternalITP, nextfloat_tdir
1010
1111eltypedual (x) = eltype (x) <: ForwardDiff.Dual
1212isdualtype (:: Type{<:ForwardDiff.Dual} ) = true
@@ -138,6 +138,322 @@ function promote_dual(::Type{T},
138138 ForwardDiff. Dual{T3, promote_dual (V, V2), N}
139139end
140140
141+ """
142+ promote_dual(::Type{T},::Type{T2})
143+
144+
145+ Is like the number promotion system, but always prefers a dual number type above
146+ anything else. For higher order differentiation, it returns the most dualiest of
147+ them all. This is then used to promote `u0` into the suspected highest differentiation
148+ space for solving the equation.
149+ """
150+ promote_dual (:: Type{T} , :: Type{T2} ) where {T, T2} = T
151+
152+ # `reduce` and `map` are specialized on tuples to be unrolled (via recursion)
153+ # Therefore, they can be type stable even with heterogeneous input types.
154+ # We also don't care about allocating any temporaries with them, as it should
155+ # all be unrolled and optimized away.
156+ # Being unrolled also means const prop can work for things like
157+ # `mapreduce(f, op, propertynames(x))`
158+ # where `f` may call `getproperty` and thus have return type dependent
159+ # on the particular symbol.
160+ # `mapreduce` hasn't received any such specialization.
161+ @inline diffeqmapreduce (f:: F , op:: OP , x:: Tuple ) where {F, OP} = reduce_tup (op, map (f, x))
162+ @inline function diffeqmapreduce (f:: F , op:: OP , x:: NamedTuple ) where {F, OP}
163+ reduce_tup (op, map (f, x))
164+ end
165+ # For other container types, we probably just want to call `mapreduce`
166+ @inline diffeqmapreduce (f:: F , op:: OP , x) where {F, OP} = mapreduce (f, op, x, init = Any)
167+
168+ struct DualEltypeChecker{T, T2}
169+ x:: T
170+ counter:: T2
171+ end
172+
173+ getval (:: Val{I} ) where {I} = I
174+ getval (:: Type{Val{I}} ) where {I} = I
175+ getval (I:: Int ) = I
176+
177+ const DUALCHECK_RECURSION_MAX = 10
178+
179+ function (dec:: DualEltypeChecker )(:: Val{Y} ) where {Y}
180+ isdefined (dec. x, Y) || return Any
181+ getval (dec. counter) >= DUALCHECK_RECURSION_MAX && return Any
182+ anyeltypedual (getfield (dec. x, Y), Val{getval (dec. counter)})
183+ end
184+
185+ # Untyped dispatch: catch composite types, check all of their fields
186+ """
187+ anyeltypedual(x)
188+
189+
190+ Searches through a type to see if any of its values are parameters. This is used to
191+ then promote other values to match the dual type. For example, if a user passes a parameter
192+
193+ which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u`
194+ will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted
195+ to a dual number before the solve. Worse still, this needs to be done in the case of
196+ `f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid
197+ way to calculate the required state type.
198+
199+ But given the properties of automatic differentiation requiring that differentiation of parameters
200+ implies differentiation of state, we assume any dual parameters implies differentiation of state
201+ and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs
202+ to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this
203+ may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where
204+ upconversion is not done automatically, the user is required to upconvert all initial conditions
205+ themselves, for an example of how this can be confusing to a user see
206+ https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
207+ """
208+ @generated function anyeltypedual (x, :: Type{Val{counter}} ) where {counter}
209+ x = x. name === Core. Compiler. typename (Type) ? x. parameters[1 ] : x
210+ if isdualtype (x)
211+ :($ x)
212+ elseif fieldnames (x) === ()
213+ :(Any)
214+ elseif counter < DUALCHECK_RECURSION_MAX
215+ T = diffeqmapreduce (x -> anyeltypedual (x, Val{counter + 1 }), promote_dual,
216+ x. parameters)
217+ if T === Any || isconcretetype (T)
218+ :($ T)
219+ else
220+ :(diffeqmapreduce (DualEltypeChecker ($ x, $ counter + 1 ), promote_dual,
221+ map (Val, fieldnames ($ (typeof (x))))))
222+ end
223+ else
224+ :(Any)
225+ end
226+ end
227+
228+ const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """
229+ Failed to automatically detect ForwardDiff compatability of
230+ the parameter object. In order for ForwardDiff.jl automatic
231+ differentiation to work on a solution object, the state of
232+ the differential equation or nonlinear solve (`u0`) needs to
233+ be converted to a Dual type which matches the values being
234+ differentiated. For example, for a loss function loss(p)
235+ where `p`` is a `Vector{Float64}`, this conversion is
236+ equivalent to:
237+
238+ ```julia
239+ # Convert u0 to match the new Dual element type of `p`
240+ _prob = remake(prob, u0 = eltype(p).(prob.u0))
241+ ```
242+
243+ In most cases, SciML tools are able to do this conversion
244+ automatically. However, it seems you have provided a
245+ parameter type for which this automatic conversion has failed.
246+
247+ To fix this, you can do the conversion yourself. For example,
248+ if you have a parameter vector being optimized `p` which is
249+ then put into an odd struct, you can manually convert `u0`
250+ to match `p`:
251+
252+ ```julia
253+ function loss(p)
254+ _prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p))
255+ sol = solve(_prob, ...)
256+ # do stuff on sol
257+ end
258+ ```
259+
260+ Or you can define a dispatch on `DiffEqBase.anyeltypedual`
261+ which tells the system what fields to interpret as the
262+ differentiable parts. For example, to support ODESolutions
263+ as parameters we tell it the data is `sol.u` and `sol.t` via:
264+
265+ ```julia
266+ function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0)
267+ DiffEqBase.anyeltypedual((sol.u, sol.t))
268+ end
269+ ```
270+
271+ To opt a type out of the dual checking, define an overload
272+ that returns Any. For example:
273+
274+ ```julia
275+ function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter}
276+ Any
277+ end
278+ ```
279+
280+ If you have defined this on a common type which should
281+ be more generally supported, please open a pull request
282+ adding this dispatch. If you need help defining this dispatch,
283+ feel free to open an issue.
284+ """
285+
286+ struct ForwardDiffAutomaticDetectionFailure <: Exception end
287+
288+ function Base. showerror (io:: IO , e:: ForwardDiffAutomaticDetectionFailure )
289+ print (io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE)
290+ end
291+
292+ function anyeltypedual (:: Type{Union{}} )
293+ throw (ForwardDiffAutomaticDetectionFailure ())
294+ end
295+
296+ function anyeltypedual (:: Type{<:AbstractTimeseriesSolution{T, N}} ,
297+ :: Type{Val{counter}} = Val{0 }) where {T, N, counter}
298+ anyeltypedual (T)
299+ end
300+
301+ function anyeltypedual (
302+ :: Type{T} ,
303+ :: Type{Val{counter}} = Val{0 }) where {counter} where {T < :
304+ NonlinearProblem{
305+ uType, iip, pType}} where {uType, iip, pType}
306+ return anyeltypedual ((uType, pType), Val{counter})
307+ end
308+
309+ function anyeltypedual (
310+ :: Type{T} ,
311+ :: Type{Val{counter}} = Val{0 }) where {counter} where {T < :
312+ NonlinearLeastSquaresProblem{
313+ uType, iip, pType}} where {uType, iip, pType}
314+ return anyeltypedual ((uType, pType), Val{counter})
315+ end
316+
317+ function anyeltypedual (x:: SciMLBase.RecipesBase.AbstractPlot ,
318+ :: Type{Val{counter}} = Val{0 }) where {counter}
319+ Any
320+ end
321+ function anyeltypedual (x:: Returns , :: Type{Val{counter}} = Val{0 }) where {counter}
322+ anyeltypedual (x. value, Val{counter})
323+ end
324+
325+ Base. @assume_effects :foldable function __anyeltypedual (:: Type{T} ) where {T}
326+ if T isa Union
327+ promote_dual (anyeltypedual (T. a), anyeltypedual (T. b))
328+ elseif hasproperty (T, :parameters )
329+ mapreduce (anyeltypedual, promote_dual, T. parameters; init = Any)
330+ else
331+ T
332+ end
333+ end
334+ function anyeltypedual (:: Type{T} , :: Type{Val{counter}} = Val{0 }) where {counter} where {T}
335+ __anyeltypedual (T)
336+ end
337+
338+ function anyeltypedual (:: Type{T} ,
339+ :: Type{Val{counter}} = Val{0 }) where {counter} where {T < :
340+ Union{AbstractArray, Set}}
341+ anyeltypedual (eltype (T))
342+ end
343+ Base. @pure function __anyeltypedual_ntuple (:: Type{T} ) where {T <: NTuple }
344+ if isconcretetype (eltype (T))
345+ return eltype (T)
346+ end
347+ if isempty (T. parameters)
348+ Any
349+ else
350+ mapreduce (anyeltypedual, promote_dual, T. parameters; init = Any)
351+ end
352+ end
353+ function anyeltypedual (
354+ :: Type{T} , :: Type{Val{counter}} = Val{0 }) where {counter} where {T <: NTuple }
355+ __anyeltypedual_ntuple (T)
356+ end
357+
358+ # Any in this context just means not Dual
359+ function anyeltypedual (
360+ x:: SciMLBase.NullParameters , :: Type{Val{counter}} = Val{0 }) where {counter}
361+ Any
362+ end
363+
364+ function anyeltypedual (sol:: RecursiveArrayTools.AbstractDiffEqArray , counter = 0 )
365+ diffeqmapreduce (anyeltypedual, promote_dual, (sol. u, sol. t))
366+ end
367+
368+ function anyeltypedual (prob:: Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem} ,
369+ :: Type{Val{counter}} = Val{0 }) where {counter}
370+ anyeltypedual ((prob. u0, prob. p, prob. tspan))
371+ end
372+
373+ function anyeltypedual (
374+ prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem} ,
375+ :: Type{Val{counter}} = Val{0 }) where {counter}
376+ anyeltypedual ((prob. u0, prob. p))
377+ end
378+
379+ function anyeltypedual (x:: Number , :: Type{Val{counter}} = Val{0 }) where {counter}
380+ anyeltypedual (typeof (x))
381+ end
382+ function anyeltypedual (
383+ x:: Union{String, Symbol} , :: Type{Val{counter}} = Val{0 }) where {counter}
384+ typeof (x)
385+ end
386+ function anyeltypedual (x:: Union{AbstractArray{T}, Set{T}} ,
387+ :: Type{Val{counter}} = Val{0 }) where {counter} where {
388+ T < :
389+ Union{Number,
390+ Symbol,
391+ String}}
392+ anyeltypedual (T)
393+ end
394+ function anyeltypedual (x:: Union{AbstractArray{T}, Set{T}} ,
395+ :: Type{Val{counter}} = Val{0 }) where {counter} where {
396+ T <: Union {
397+ AbstractArray{
398+ <: Number ,
399+ },
400+ Set{
401+ <: Number ,
402+ }}}
403+ anyeltypedual (eltype (x))
404+ end
405+ function anyeltypedual (x:: Union{AbstractArray{T}, Set{T}} ,
406+ :: Type{Val{counter}} = Val{0 }) where {counter} where {N, T <: NTuple{N, <:Number} }
407+ anyeltypedual (eltype (x))
408+ end
409+
410+ # Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x))
411+ function anyeltypedual (x:: AbstractArray , :: Type{Val{counter}} = Val{0 }) where {counter}
412+ if isconcretetype (eltype (x))
413+ anyeltypedual (eltype (x))
414+ elseif ! isempty (x) && all (i -> isassigned (x, i), 1 : length (x)) &&
415+ counter < DUALCHECK_RECURSION_MAX
416+ _counter = Val{counter + 1 }
417+ mapreduce (y -> anyeltypedual (y, _counter), promote_dual, x)
418+ else
419+ # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
420+ # misses cases of
421+ Any
422+ end
423+ end
424+
425+ function anyeltypedual (x:: Set , :: Type{Val{counter}} = Val{0 }) where {counter}
426+ if isconcretetype (eltype (x))
427+ anyeltypedual (eltype (x))
428+ else
429+ # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
430+ Any
431+ end
432+ end
433+
434+ function anyeltypedual (x:: Tuple , :: Type{Val{counter}} = Val{0 }) where {counter}
435+ # Handle the empty tuple case separately for inference and to avoid mapreduce error
436+ if x === ()
437+ Any
438+ else
439+ diffeqmapreduce (anyeltypedual, promote_dual, x)
440+ end
441+ end
442+ function anyeltypedual (x:: Dict , :: Type{Val{counter}} = Val{0 }) where {counter}
443+ isempty (x) ? eltype (values (x)) : mapreduce (anyeltypedual, promote_dual, values (x))
444+ end
445+ function anyeltypedual (x:: NamedTuple , :: Type{Val{counter}} = Val{0 }) where {counter}
446+ anyeltypedual (values (x))
447+ end
448+
449+ function anyeltypedual (
450+ f:: SciMLBase.AbstractSciMLFunction , :: Type{Val{counter}} ) where {counter}
451+ Any
452+ end
453+
454+ anyeltypedual (:: @Kwargs {}, :: Type{Val{counter}} = Val{0 }) where {counter} = Any
455+ anyeltypedual (:: Type{@Kwargs{}} , :: Type{Val{counter}} = Val{0 }) where {counter} = Any
456+
141457# Opt out since these are using for preallocation, not differentiation
142458function anyeltypedual (x:: Union{ForwardDiff.AbstractConfig, Module} ,
143459 :: Type{Val{counter}} = Val{0 }) where {counter}
0 commit comments