@@ -3,13 +3,16 @@ module SciMLBaseForwardDiffExt
33using SciMLBase, ForwardDiff
44using ArrayInterface
55
6- import SciMLBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin, isdualtype, value
6+ import SciMLBase:
7+ wrapfun_oop, wrapfun_iip, isdualtype, value, DualEltypeChecker,
8+ AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem,
9+ ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem,
10+ RecursiveArrayTools, totallength
11+
712
813
914eltypedual (x) = eltype (x) <: ForwardDiff.Dual
1015isdualtype (:: Type{<:ForwardDiff.Dual} ) = true
11- const dualT = ForwardDiff. Dual{ForwardDiff. Tag{OrdinaryDiffEqTag, Float64}, Float64, 1 }
12- dualgen (:: Type{T} ) where {T} = ForwardDiff. Dual{ForwardDiff. Tag{OrdinaryDiffEqTag, T}, T, 1 }
1316
1417# Copy of the other prob2dtmin dispatch, just for optionality
1518function prob2dtmin (tspan, :: ForwardDiff.Dual , use_end_time)
@@ -22,91 +25,16 @@ function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time)
2225 end
2326end
2427
25- function hasdualpromote (u0, t:: Number )
26- hasmethod (ArrayInterface. promote_eltype,
27- Tuple{Type{typeof (u0)}, Type{dualgen (eltype (u0))}}) &&
28- hasmethod (promote_rule,
29- Tuple{Type{eltype (u0)}, Type{dualgen (eltype (u0))}}) &&
30- hasmethod (promote_rule,
31- Tuple{Type{eltype (u0)}, Type{typeof (t)}})
32- end
33-
34- const NORECOMPILE_IIP_SUPPORTED_ARGS = (
35- Tuple{Vector{Float64}, Vector{Float64},
36- Vector{Float64}, Float64},
37- Tuple{Vector{Float64}, Vector{Float64},
38- SciMLBase. NullParameters, Float64})
39-
40- const oop_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Float64},
41- Tuple{Vector{Float64}, SciMLBase. NullParameters, Float64},
42- Tuple{Vector{Float64}, Vector{Float64}, dualT},
43- Tuple{Vector{dualT}, Vector{Float64}, Float64},
44- Tuple{Vector{dualT}, SciMLBase. NullParameters, Float64},
45- Tuple{Vector{Float64}, SciMLBase. NullParameters, dualT})
46-
47- const NORECOMPILE_OOP_SUPPORTED_ARGS = (Tuple{Vector{Float64},
48- Vector{Float64}, Float64},
49- Tuple{Vector{Float64},
50- SciMLBase. NullParameters, Float64})
51- const oop_returnlists = (Vector{Float64}, Vector{Float64},
52- ntuple (x -> Vector{dualT}, length (oop_arglists) - 2 )... )
53-
54- function wrapfun_oop (ff, inputs:: Tuple = ())
55- if ! isempty (inputs)
56- IT = Tuple{map (typeof, inputs)... }
57- if IT ∉ NORECOMPILE_OOP_SUPPORTED_ARGS
58- throw (NoRecompileArgumentError (IT))
59- end
60- end
61- FunctionWrappersWrappers. FunctionWrappersWrapper (ff, oop_arglists,
62- oop_returnlists)
63- end
64-
65- function wrapfun_iip (ff,
66- inputs:: Tuple{T1, T2, T3, T4} ) where {T1, T2, T3, T4}
67- T = eltype (T2)
68- dualT = dualgen (T)
69- dualT1 = ArrayInterface. promote_eltype (T1, dualT)
70- dualT2 = ArrayInterface. promote_eltype (T2, dualT)
71- dualT4 = dualgen (promote_type (T, T4))
28+ # function hasdualpromote(u0, t::Number)
29+ # hasmethod(ArrayInterface.promote_eltype,
30+ # Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
31+ # hasmethod(promote_rule,
32+ # Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
33+ # hasmethod(promote_rule,
34+ # Tuple{Type{eltype(u0)}, Type{typeof(t)}})
35+ # end
7236
73- iip_arglists = (Tuple{T1, T2, T3, T4},
74- Tuple{dualT1, dualT2, T3, T4},
75- Tuple{dualT1, T2, T3, dualT4},
76- Tuple{dualT1, dualT2, T3, dualT4})
7737
78- iip_returnlists = ntuple (x -> Nothing, 4 )
79-
80- fwt = map (iip_arglists, iip_returnlists) do A, R
81- FunctionWrappersWrappers. FunctionWrappers. FunctionWrapper {R, A} (Void (ff))
82- end
83- FunctionWrappersWrappers. FunctionWrappersWrapper {typeof(fwt), false} (fwt)
84- end
85-
86- const iip_arglists_default = (
87- Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64},
88- Float64},
89- Tuple{Vector{Float64}, Vector{Float64},
90- SciMLBase. NullParameters,
91- Float64
92- },
93- Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
94- Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, dualT},
95- Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
96- Tuple{Vector{dualT}, Vector{dualT}, SciMLBase. NullParameters,
97- Float64
98- },
99- Tuple{Vector{dualT}, Vector{Float64},
100- SciMLBase. NullParameters, dualT
101- })
102- const iip_returnlists_default = ntuple (x -> Nothing, length (iip_arglists_default))
103-
104- function wrapfun_iip (@nospecialize (ff))
105- fwt = map (iip_arglists_default, iip_returnlists_default) do A, R
106- FunctionWrappersWrappers. FunctionWrappers. FunctionWrapper {R, A} (Void (ff))
107- end
108- FunctionWrappersWrappers. FunctionWrappersWrapper {typeof(fwt), false} (fwt)
109- end
11038
11139promote_dual (:: Type{T} , :: Type{T2} ) where {T <: ForwardDiff.Dual , T2} = T
11240function promote_dual (:: Type{T} ,
@@ -497,9 +425,9 @@ unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
497425unitfulvalue (x:: ForwardDiff.Dual ) = unitfulvalue (ForwardDiff. value (x))
498426
499427sse (x:: ForwardDiff.Dual ) = sse (ForwardDiff. value (x)) + sum (sse, ForwardDiff. partials (x))
500- function DiffEqBase . totallength (x:: ForwardDiff.Dual )
501- return DiffEqBase . totallength (ForwardDiff. value (x)) +
502- sum (DiffEqBase . totallength, ForwardDiff. partials (x))
428+ function SciMLBase . totallength (x:: ForwardDiff.Dual )
429+ return SciMLBase . totallength (ForwardDiff. value (x)) +
430+ sum (SciMLBase . totallength, ForwardDiff. partials (x))
503431end
504432
505433end
0 commit comments