Skip to content

Commit 3bb95d9

Browse files
Fix inference
1 parent c6c0c04 commit 3bb95d9

File tree

3 files changed

+329
-326
lines changed

3 files changed

+329
-326
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 320 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ module DiffEqBaseForwardDiffExt
22

33
using DiffEqBase, ForwardDiff
44
using DiffEqBase.ArrayInterface
5-
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
6-
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0, prob2dtmin,
5+
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution,
6+
RecursiveArrayTools, reduce_tup
7+
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
78
promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
8-
InternalITP,
9-
nextfloat_tdir, promote_dual
9+
InternalITP, nextfloat_tdir
1010

1111
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
1212
isdualtype(::Type{<:ForwardDiff.Dual}) = true
@@ -138,6 +138,322 @@ function promote_dual(::Type{T},
138138
ForwardDiff.Dual{T3, promote_dual(V, V2), N}
139139
end
140140

141+
"""
142+
promote_dual(::Type{T},::Type{T2})
143+
144+
145+
Is like the number promotion system, but always prefers a dual number type above
146+
anything else. For higher order differentiation, it returns the most dualiest of
147+
them all. This is then used to promote `u0` into the suspected highest differentiation
148+
space for solving the equation.
149+
"""
150+
promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T
151+
152+
# `reduce` and `map` are specialized on tuples to be unrolled (via recursion)
153+
# Therefore, they can be type stable even with heterogeneous input types.
154+
# We also don't care about allocating any temporaries with them, as it should
155+
# all be unrolled and optimized away.
156+
# Being unrolled also means const prop can work for things like
157+
# `mapreduce(f, op, propertynames(x))`
158+
# where `f` may call `getproperty` and thus have return type dependent
159+
# on the particular symbol.
160+
# `mapreduce` hasn't received any such specialization.
161+
@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x))
162+
@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP}
163+
reduce_tup(op, map(f, x))
164+
end
165+
# For other container types, we probably just want to call `mapreduce`
166+
@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init = Any)
167+
168+
struct DualEltypeChecker{T, T2}
169+
x::T
170+
counter::T2
171+
end
172+
173+
getval(::Val{I}) where {I} = I
174+
getval(::Type{Val{I}}) where {I} = I
175+
getval(I::Int) = I
176+
177+
const DUALCHECK_RECURSION_MAX = 10
178+
179+
function (dec::DualEltypeChecker)(::Val{Y}) where {Y}
180+
isdefined(dec.x, Y) || return Any
181+
getval(dec.counter) >= DUALCHECK_RECURSION_MAX && return Any
182+
anyeltypedual(getfield(dec.x, Y), Val{getval(dec.counter)})
183+
end
184+
185+
# Untyped dispatch: catch composite types, check all of their fields
186+
"""
187+
anyeltypedual(x)
188+
189+
190+
Searches through a type to see if any of its values are parameters. This is used to
191+
then promote other values to match the dual type. For example, if a user passes a parameter
192+
193+
which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u`
194+
will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted
195+
to a dual number before the solve. Worse still, this needs to be done in the case of
196+
`f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid
197+
way to calculate the required state type.
198+
199+
But given the properties of automatic differentiation requiring that differentiation of parameters
200+
implies differentiation of state, we assume any dual parameters implies differentiation of state
201+
and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs
202+
to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this
203+
may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where
204+
upconversion is not done automatically, the user is required to upconvert all initial conditions
205+
themselves, for an example of how this can be confusing to a user see
206+
https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
207+
"""
208+
@generated function anyeltypedual(x, ::Type{Val{counter}}) where {counter}
209+
x = x.name === Core.Compiler.typename(Type) ? x.parameters[1] : x
210+
if isdualtype(x)
211+
:($x)
212+
elseif fieldnames(x) === ()
213+
:(Any)
214+
elseif counter < DUALCHECK_RECURSION_MAX
215+
T = diffeqmapreduce(x -> anyeltypedual(x, Val{counter + 1}), promote_dual,
216+
x.parameters)
217+
if T === Any || isconcretetype(T)
218+
:($T)
219+
else
220+
:(diffeqmapreduce(DualEltypeChecker($x, $counter + 1), promote_dual,
221+
map(Val, fieldnames($(typeof(x))))))
222+
end
223+
else
224+
:(Any)
225+
end
226+
end
227+
228+
const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """
229+
Failed to automatically detect ForwardDiff compatability of
230+
the parameter object. In order for ForwardDiff.jl automatic
231+
differentiation to work on a solution object, the state of
232+
the differential equation or nonlinear solve (`u0`) needs to
233+
be converted to a Dual type which matches the values being
234+
differentiated. For example, for a loss function loss(p)
235+
where `p`` is a `Vector{Float64}`, this conversion is
236+
equivalent to:
237+
238+
```julia
239+
# Convert u0 to match the new Dual element type of `p`
240+
_prob = remake(prob, u0 = eltype(p).(prob.u0))
241+
```
242+
243+
In most cases, SciML tools are able to do this conversion
244+
automatically. However, it seems you have provided a
245+
parameter type for which this automatic conversion has failed.
246+
247+
To fix this, you can do the conversion yourself. For example,
248+
if you have a parameter vector being optimized `p` which is
249+
then put into an odd struct, you can manually convert `u0`
250+
to match `p`:
251+
252+
```julia
253+
function loss(p)
254+
_prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p))
255+
sol = solve(_prob, ...)
256+
# do stuff on sol
257+
end
258+
```
259+
260+
Or you can define a dispatch on `DiffEqBase.anyeltypedual`
261+
which tells the system what fields to interpret as the
262+
differentiable parts. For example, to support ODESolutions
263+
as parameters we tell it the data is `sol.u` and `sol.t` via:
264+
265+
```julia
266+
function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0)
267+
DiffEqBase.anyeltypedual((sol.u, sol.t))
268+
end
269+
```
270+
271+
To opt a type out of the dual checking, define an overload
272+
that returns Any. For example:
273+
274+
```julia
275+
function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter}
276+
Any
277+
end
278+
```
279+
280+
If you have defined this on a common type which should
281+
be more generally supported, please open a pull request
282+
adding this dispatch. If you need help defining this dispatch,
283+
feel free to open an issue.
284+
"""
285+
286+
struct ForwardDiffAutomaticDetectionFailure <: Exception end
287+
288+
function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure)
289+
print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE)
290+
end
291+
292+
function anyeltypedual(::Type{Union{}})
293+
throw(ForwardDiffAutomaticDetectionFailure())
294+
end
295+
296+
function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T, N}},
297+
::Type{Val{counter}} = Val{0}) where {T, N, counter}
298+
anyeltypedual(T)
299+
end
300+
301+
function anyeltypedual(
302+
::Type{T},
303+
::Type{Val{counter}} = Val{0}) where {counter} where {T <:
304+
NonlinearProblem{
305+
uType, iip, pType}} where {uType, iip, pType}
306+
return anyeltypedual((uType, pType), Val{counter})
307+
end
308+
309+
function anyeltypedual(
310+
::Type{T},
311+
::Type{Val{counter}} = Val{0}) where {counter} where {T <:
312+
NonlinearLeastSquaresProblem{
313+
uType, iip, pType}} where {uType, iip, pType}
314+
return anyeltypedual((uType, pType), Val{counter})
315+
end
316+
317+
function anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot,
318+
::Type{Val{counter}} = Val{0}) where {counter}
319+
Any
320+
end
321+
function anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter}
322+
anyeltypedual(x.value, Val{counter})
323+
end
324+
325+
Base.@assume_effects :foldable function __anyeltypedual(::Type{T}) where {T}
326+
if T isa Union
327+
promote_dual(anyeltypedual(T.a), anyeltypedual(T.b))
328+
elseif hasproperty(T, :parameters)
329+
mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any)
330+
else
331+
T
332+
end
333+
end
334+
function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T}
335+
__anyeltypedual(T)
336+
end
337+
338+
function anyeltypedual(::Type{T},
339+
::Type{Val{counter}} = Val{0}) where {counter} where {T <:
340+
Union{AbstractArray, Set}}
341+
anyeltypedual(eltype(T))
342+
end
343+
Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple}
344+
if isconcretetype(eltype(T))
345+
return eltype(T)
346+
end
347+
if isempty(T.parameters)
348+
Any
349+
else
350+
mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any)
351+
end
352+
end
353+
function anyeltypedual(
354+
::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple}
355+
__anyeltypedual_ntuple(T)
356+
end
357+
358+
# Any in this context just means not Dual
359+
function anyeltypedual(
360+
x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter}
361+
Any
362+
end
363+
364+
function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0)
365+
diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t))
366+
end
367+
368+
function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem},
369+
::Type{Val{counter}} = Val{0}) where {counter}
370+
anyeltypedual((prob.u0, prob.p, prob.tspan))
371+
end
372+
373+
function anyeltypedual(
374+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem},
375+
::Type{Val{counter}} = Val{0}) where {counter}
376+
anyeltypedual((prob.u0, prob.p))
377+
end
378+
379+
function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter}
380+
anyeltypedual(typeof(x))
381+
end
382+
function anyeltypedual(
383+
x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter}
384+
typeof(x)
385+
end
386+
function anyeltypedual(x::Union{AbstractArray{T}, Set{T}},
387+
::Type{Val{counter}} = Val{0}) where {counter} where {
388+
T <:
389+
Union{Number,
390+
Symbol,
391+
String}}
392+
anyeltypedual(T)
393+
end
394+
function anyeltypedual(x::Union{AbstractArray{T}, Set{T}},
395+
::Type{Val{counter}} = Val{0}) where {counter} where {
396+
T <: Union{
397+
AbstractArray{
398+
<:Number,
399+
},
400+
Set{
401+
<:Number,
402+
}}}
403+
anyeltypedual(eltype(x))
404+
end
405+
function anyeltypedual(x::Union{AbstractArray{T}, Set{T}},
406+
::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}}
407+
anyeltypedual(eltype(x))
408+
end
409+
410+
# Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x))
411+
function anyeltypedual(x::AbstractArray, ::Type{Val{counter}} = Val{0}) where {counter}
412+
if isconcretetype(eltype(x))
413+
anyeltypedual(eltype(x))
414+
elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) &&
415+
counter < DUALCHECK_RECURSION_MAX
416+
_counter = Val{counter + 1}
417+
mapreduce(y -> anyeltypedual(y, _counter), promote_dual, x)
418+
else
419+
# This fallback to Any is required since otherwise we cannot handle `undef` in all cases
420+
# misses cases of
421+
Any
422+
end
423+
end
424+
425+
function anyeltypedual(x::Set, ::Type{Val{counter}} = Val{0}) where {counter}
426+
if isconcretetype(eltype(x))
427+
anyeltypedual(eltype(x))
428+
else
429+
# This fallback to Any is required since otherwise we cannot handle `undef` in all cases
430+
Any
431+
end
432+
end
433+
434+
function anyeltypedual(x::Tuple, ::Type{Val{counter}} = Val{0}) where {counter}
435+
# Handle the empty tuple case separately for inference and to avoid mapreduce error
436+
if x === ()
437+
Any
438+
else
439+
diffeqmapreduce(anyeltypedual, promote_dual, x)
440+
end
441+
end
442+
function anyeltypedual(x::Dict, ::Type{Val{counter}} = Val{0}) where {counter}
443+
isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x))
444+
end
445+
function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {counter}
446+
anyeltypedual(values(x))
447+
end
448+
449+
function anyeltypedual(
450+
f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter}
451+
Any
452+
end
453+
454+
anyeltypedual(::@Kwargs{}, ::Type{Val{counter}} = Val{0}) where {counter} = Any
455+
anyeltypedual(::Type{@Kwargs{}}, ::Type{Val{counter}} = Val{0}) where {counter} = Any
456+
141457
# Opt out since these are using for preallocation, not differentiation
142458
function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module},
143459
::Type{Val{counter}} = Val{0}) where {counter}

src/DiffEqBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true
110110

111111
eltypedual(x) = false
112112
promote_u0(::Nothing, p, t0) = nothing
113-
isdualtype(::Type{T}) where {T} = true
113+
isdualtype(::Type{T}) where {T} = false
114114

115115
## Types
116116

0 commit comments

Comments
 (0)