@@ -2,11 +2,11 @@ module DiffEqBaseForwardDiffExt
2
2
3
3
using DiffEqBase, ForwardDiff
4
4
using DiffEqBase. ArrayInterface
5
- using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
6
- import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0, prob2dtmin,
5
+ using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution,
6
+ RecursiveArrayTools, reduce_tup
7
+ import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
7
8
promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
8
- InternalITP,
9
- nextfloat_tdir, promote_dual
9
+ InternalITP, nextfloat_tdir
10
10
11
11
eltypedual (x) = eltype (x) <: ForwardDiff.Dual
12
12
isdualtype (:: Type{<:ForwardDiff.Dual} ) = true
@@ -138,6 +138,322 @@ function promote_dual(::Type{T},
138
138
ForwardDiff. Dual{T3, promote_dual (V, V2), N}
139
139
end
140
140
141
+ """
142
+ promote_dual(::Type{T},::Type{T2})
143
+
144
+
145
+ Is like the number promotion system, but always prefers a dual number type above
146
+ anything else. For higher order differentiation, it returns the most dualiest of
147
+ them all. This is then used to promote `u0` into the suspected highest differentiation
148
+ space for solving the equation.
149
+ """
150
+ promote_dual (:: Type{T} , :: Type{T2} ) where {T, T2} = T
151
+
152
+ # `reduce` and `map` are specialized on tuples to be unrolled (via recursion)
153
+ # Therefore, they can be type stable even with heterogeneous input types.
154
+ # We also don't care about allocating any temporaries with them, as it should
155
+ # all be unrolled and optimized away.
156
+ # Being unrolled also means const prop can work for things like
157
+ # `mapreduce(f, op, propertynames(x))`
158
+ # where `f` may call `getproperty` and thus have return type dependent
159
+ # on the particular symbol.
160
+ # `mapreduce` hasn't received any such specialization.
161
+ @inline diffeqmapreduce (f:: F , op:: OP , x:: Tuple ) where {F, OP} = reduce_tup (op, map (f, x))
162
+ @inline function diffeqmapreduce (f:: F , op:: OP , x:: NamedTuple ) where {F, OP}
163
+ reduce_tup (op, map (f, x))
164
+ end
165
+ # For other container types, we probably just want to call `mapreduce`
166
+ @inline diffeqmapreduce (f:: F , op:: OP , x) where {F, OP} = mapreduce (f, op, x, init = Any)
167
+
168
+ struct DualEltypeChecker{T, T2}
169
+ x:: T
170
+ counter:: T2
171
+ end
172
+
173
+ getval (:: Val{I} ) where {I} = I
174
+ getval (:: Type{Val{I}} ) where {I} = I
175
+ getval (I:: Int ) = I
176
+
177
+ const DUALCHECK_RECURSION_MAX = 10
178
+
179
+ function (dec:: DualEltypeChecker )(:: Val{Y} ) where {Y}
180
+ isdefined (dec. x, Y) || return Any
181
+ getval (dec. counter) >= DUALCHECK_RECURSION_MAX && return Any
182
+ anyeltypedual (getfield (dec. x, Y), Val{getval (dec. counter)})
183
+ end
184
+
185
+ # Untyped dispatch: catch composite types, check all of their fields
186
+ """
187
+ anyeltypedual(x)
188
+
189
+
190
+ Searches through a type to see if any of its values are parameters. This is used to
191
+ then promote other values to match the dual type. For example, if a user passes a parameter
192
+
193
+ which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u`
194
+ will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted
195
+ to a dual number before the solve. Worse still, this needs to be done in the case of
196
+ `f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid
197
+ way to calculate the required state type.
198
+
199
+ But given the properties of automatic differentiation requiring that differentiation of parameters
200
+ implies differentiation of state, we assume any dual parameters implies differentiation of state
201
+ and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs
202
+ to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this
203
+ may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where
204
+ upconversion is not done automatically, the user is required to upconvert all initial conditions
205
+ themselves, for an example of how this can be confusing to a user see
206
+ https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
207
+ """
208
+ @generated function anyeltypedual (x, :: Type{Val{counter}} ) where {counter}
209
+ x = x. name === Core. Compiler. typename (Type) ? x. parameters[1 ] : x
210
+ if isdualtype (x)
211
+ :($ x)
212
+ elseif fieldnames (x) === ()
213
+ :(Any)
214
+ elseif counter < DUALCHECK_RECURSION_MAX
215
+ T = diffeqmapreduce (x -> anyeltypedual (x, Val{counter + 1 }), promote_dual,
216
+ x. parameters)
217
+ if T === Any || isconcretetype (T)
218
+ :($ T)
219
+ else
220
+ :(diffeqmapreduce (DualEltypeChecker ($ x, $ counter + 1 ), promote_dual,
221
+ map (Val, fieldnames ($ (typeof (x))))))
222
+ end
223
+ else
224
+ :(Any)
225
+ end
226
+ end
227
+
228
+ const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """
229
+ Failed to automatically detect ForwardDiff compatability of
230
+ the parameter object. In order for ForwardDiff.jl automatic
231
+ differentiation to work on a solution object, the state of
232
+ the differential equation or nonlinear solve (`u0`) needs to
233
+ be converted to a Dual type which matches the values being
234
+ differentiated. For example, for a loss function loss(p)
235
+ where `p`` is a `Vector{Float64}`, this conversion is
236
+ equivalent to:
237
+
238
+ ```julia
239
+ # Convert u0 to match the new Dual element type of `p`
240
+ _prob = remake(prob, u0 = eltype(p).(prob.u0))
241
+ ```
242
+
243
+ In most cases, SciML tools are able to do this conversion
244
+ automatically. However, it seems you have provided a
245
+ parameter type for which this automatic conversion has failed.
246
+
247
+ To fix this, you can do the conversion yourself. For example,
248
+ if you have a parameter vector being optimized `p` which is
249
+ then put into an odd struct, you can manually convert `u0`
250
+ to match `p`:
251
+
252
+ ```julia
253
+ function loss(p)
254
+ _prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p))
255
+ sol = solve(_prob, ...)
256
+ # do stuff on sol
257
+ end
258
+ ```
259
+
260
+ Or you can define a dispatch on `DiffEqBase.anyeltypedual`
261
+ which tells the system what fields to interpret as the
262
+ differentiable parts. For example, to support ODESolutions
263
+ as parameters we tell it the data is `sol.u` and `sol.t` via:
264
+
265
+ ```julia
266
+ function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0)
267
+ DiffEqBase.anyeltypedual((sol.u, sol.t))
268
+ end
269
+ ```
270
+
271
+ To opt a type out of the dual checking, define an overload
272
+ that returns Any. For example:
273
+
274
+ ```julia
275
+ function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter}
276
+ Any
277
+ end
278
+ ```
279
+
280
+ If you have defined this on a common type which should
281
+ be more generally supported, please open a pull request
282
+ adding this dispatch. If you need help defining this dispatch,
283
+ feel free to open an issue.
284
+ """
285
+
286
+ struct ForwardDiffAutomaticDetectionFailure <: Exception end
287
+
288
+ function Base. showerror (io:: IO , e:: ForwardDiffAutomaticDetectionFailure )
289
+ print (io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE)
290
+ end
291
+
292
+ function anyeltypedual (:: Type{Union{}} )
293
+ throw (ForwardDiffAutomaticDetectionFailure ())
294
+ end
295
+
296
+ function anyeltypedual (:: Type{<:AbstractTimeseriesSolution{T, N}} ,
297
+ :: Type{Val{counter}} = Val{0 }) where {T, N, counter}
298
+ anyeltypedual (T)
299
+ end
300
+
301
+ function anyeltypedual (
302
+ :: Type{T} ,
303
+ :: Type{Val{counter}} = Val{0 }) where {counter} where {T < :
304
+ NonlinearProblem{
305
+ uType, iip, pType}} where {uType, iip, pType}
306
+ return anyeltypedual ((uType, pType), Val{counter})
307
+ end
308
+
309
+ function anyeltypedual (
310
+ :: Type{T} ,
311
+ :: Type{Val{counter}} = Val{0 }) where {counter} where {T < :
312
+ NonlinearLeastSquaresProblem{
313
+ uType, iip, pType}} where {uType, iip, pType}
314
+ return anyeltypedual ((uType, pType), Val{counter})
315
+ end
316
+
317
+ function anyeltypedual (x:: SciMLBase.RecipesBase.AbstractPlot ,
318
+ :: Type{Val{counter}} = Val{0 }) where {counter}
319
+ Any
320
+ end
321
+ function anyeltypedual (x:: Returns , :: Type{Val{counter}} = Val{0 }) where {counter}
322
+ anyeltypedual (x. value, Val{counter})
323
+ end
324
+
325
+ Base. @assume_effects :foldable function __anyeltypedual (:: Type{T} ) where {T}
326
+ if T isa Union
327
+ promote_dual (anyeltypedual (T. a), anyeltypedual (T. b))
328
+ elseif hasproperty (T, :parameters )
329
+ mapreduce (anyeltypedual, promote_dual, T. parameters; init = Any)
330
+ else
331
+ T
332
+ end
333
+ end
334
+ function anyeltypedual (:: Type{T} , :: Type{Val{counter}} = Val{0 }) where {counter} where {T}
335
+ __anyeltypedual (T)
336
+ end
337
+
338
+ function anyeltypedual (:: Type{T} ,
339
+ :: Type{Val{counter}} = Val{0 }) where {counter} where {T < :
340
+ Union{AbstractArray, Set}}
341
+ anyeltypedual (eltype (T))
342
+ end
343
+ Base. @pure function __anyeltypedual_ntuple (:: Type{T} ) where {T <: NTuple }
344
+ if isconcretetype (eltype (T))
345
+ return eltype (T)
346
+ end
347
+ if isempty (T. parameters)
348
+ Any
349
+ else
350
+ mapreduce (anyeltypedual, promote_dual, T. parameters; init = Any)
351
+ end
352
+ end
353
+ function anyeltypedual (
354
+ :: Type{T} , :: Type{Val{counter}} = Val{0 }) where {counter} where {T <: NTuple }
355
+ __anyeltypedual_ntuple (T)
356
+ end
357
+
358
+ # Any in this context just means not Dual
359
+ function anyeltypedual (
360
+ x:: SciMLBase.NullParameters , :: Type{Val{counter}} = Val{0 }) where {counter}
361
+ Any
362
+ end
363
+
364
+ function anyeltypedual (sol:: RecursiveArrayTools.AbstractDiffEqArray , counter = 0 )
365
+ diffeqmapreduce (anyeltypedual, promote_dual, (sol. u, sol. t))
366
+ end
367
+
368
+ function anyeltypedual (prob:: Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem} ,
369
+ :: Type{Val{counter}} = Val{0 }) where {counter}
370
+ anyeltypedual ((prob. u0, prob. p, prob. tspan))
371
+ end
372
+
373
+ function anyeltypedual (
374
+ prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem} ,
375
+ :: Type{Val{counter}} = Val{0 }) where {counter}
376
+ anyeltypedual ((prob. u0, prob. p))
377
+ end
378
+
379
+ function anyeltypedual (x:: Number , :: Type{Val{counter}} = Val{0 }) where {counter}
380
+ anyeltypedual (typeof (x))
381
+ end
382
+ function anyeltypedual (
383
+ x:: Union{String, Symbol} , :: Type{Val{counter}} = Val{0 }) where {counter}
384
+ typeof (x)
385
+ end
386
+ function anyeltypedual (x:: Union{AbstractArray{T}, Set{T}} ,
387
+ :: Type{Val{counter}} = Val{0 }) where {counter} where {
388
+ T < :
389
+ Union{Number,
390
+ Symbol,
391
+ String}}
392
+ anyeltypedual (T)
393
+ end
394
+ function anyeltypedual (x:: Union{AbstractArray{T}, Set{T}} ,
395
+ :: Type{Val{counter}} = Val{0 }) where {counter} where {
396
+ T <: Union {
397
+ AbstractArray{
398
+ <: Number ,
399
+ },
400
+ Set{
401
+ <: Number ,
402
+ }}}
403
+ anyeltypedual (eltype (x))
404
+ end
405
+ function anyeltypedual (x:: Union{AbstractArray{T}, Set{T}} ,
406
+ :: Type{Val{counter}} = Val{0 }) where {counter} where {N, T <: NTuple{N, <:Number} }
407
+ anyeltypedual (eltype (x))
408
+ end
409
+
410
+ # Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x))
411
+ function anyeltypedual (x:: AbstractArray , :: Type{Val{counter}} = Val{0 }) where {counter}
412
+ if isconcretetype (eltype (x))
413
+ anyeltypedual (eltype (x))
414
+ elseif ! isempty (x) && all (i -> isassigned (x, i), 1 : length (x)) &&
415
+ counter < DUALCHECK_RECURSION_MAX
416
+ _counter = Val{counter + 1 }
417
+ mapreduce (y -> anyeltypedual (y, _counter), promote_dual, x)
418
+ else
419
+ # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
420
+ # misses cases of
421
+ Any
422
+ end
423
+ end
424
+
425
+ function anyeltypedual (x:: Set , :: Type{Val{counter}} = Val{0 }) where {counter}
426
+ if isconcretetype (eltype (x))
427
+ anyeltypedual (eltype (x))
428
+ else
429
+ # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
430
+ Any
431
+ end
432
+ end
433
+
434
+ function anyeltypedual (x:: Tuple , :: Type{Val{counter}} = Val{0 }) where {counter}
435
+ # Handle the empty tuple case separately for inference and to avoid mapreduce error
436
+ if x === ()
437
+ Any
438
+ else
439
+ diffeqmapreduce (anyeltypedual, promote_dual, x)
440
+ end
441
+ end
442
+ function anyeltypedual (x:: Dict , :: Type{Val{counter}} = Val{0 }) where {counter}
443
+ isempty (x) ? eltype (values (x)) : mapreduce (anyeltypedual, promote_dual, values (x))
444
+ end
445
+ function anyeltypedual (x:: NamedTuple , :: Type{Val{counter}} = Val{0 }) where {counter}
446
+ anyeltypedual (values (x))
447
+ end
448
+
449
+ function anyeltypedual (
450
+ f:: SciMLBase.AbstractSciMLFunction , :: Type{Val{counter}} ) where {counter}
451
+ Any
452
+ end
453
+
454
+ anyeltypedual (:: @Kwargs {}, :: Type{Val{counter}} = Val{0 }) where {counter} = Any
455
+ anyeltypedual (:: Type{@Kwargs{}} , :: Type{Val{counter}} = Val{0 }) where {counter} = Any
456
+
141
457
# Opt out since these are using for preallocation, not differentiation
142
458
function anyeltypedual (x:: Union{ForwardDiff.AbstractConfig, Module} ,
143
459
:: Type{Val{counter}} = Val{0 }) where {counter}
0 commit comments