Skip to content

Commit 9ea1147

Browse files
committed
feat: use specialization in NonlinearProblems
1 parent 9de748d commit 9ea1147

File tree

3 files changed

+69
-25
lines changed

3 files changed

+69
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.159.0"
4+
version = "6.160.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/norecompile.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ function wrapfun_iip(ff,
7373
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
7474
end
7575

76+
function wrapfun_iip(ff,
77+
inputs::Tuple{T1, T2, T3}) where {T1, T2, T3}
78+
T = eltype(T2)
79+
dualT = dualgen(T)
80+
dualT1 = ArrayInterface.promote_eltype(T1, dualT)
81+
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
82+
83+
iip_arglists = (Tuple{T1, T2, T3},
84+
Tuple{dualT1, dualT2, T3},
85+
Tuple{dualT1, T2, T3})
86+
87+
iip_returnlists = ntuple(x -> Nothing, 3)
88+
89+
fwt = map(iip_arglists, iip_returnlists) do A, R
90+
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
91+
end
92+
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
93+
end
94+
7695
const iip_arglists_default = (
7796
Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64},
7897
Float64},

src/solve.jl

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,14 +1152,20 @@ function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...)
11521152
p = get_concrete_p(prob, kwargs)
11531153
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
11541154
u0 = promote_u0(u0, p, nothing)
1155-
remake(prob; u0 = u0, p = p)
1155+
f_promote = promote_f(
1156+
prob.f, Val(SciMLBase.specialization(prob.f)), u0, p
1157+
)
1158+
remake(prob; u0 = u0, p = p, f = f_promote)
11561159
end
11571160

11581161
function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...)
11591162
p = get_concrete_p(prob, kwargs)
11601163
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
11611164
u0 = promote_u0(u0, p, nothing)
1162-
remake(prob; u0 = u0, p = p)
1165+
f_promote = promote_f(
1166+
prob.f, Val(SciMLBase.specialization(prob.f)), u0, p
1167+
)
1168+
remake(prob; u0 = u0, p = p, f = f_promote)
11631169
end
11641170

11651171
function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...)
@@ -1252,28 +1258,47 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize}
12521258
f = @set f.jac_prototype = similar(f.jac_prototype, uElType)
12531259
end
12541260

1255-
@static if VERSION >= v"1.8-"
1256-
f = if f isa ODEFunction && isinplace(f) && !(f.f isa AbstractSciMLOperator) &&
1257-
# Some reinitialization code still uses NLSolvers stuff which doesn't
1258-
# properly tag, so opt-out if potentially a mass matrix DAE
1259-
f.mass_matrix isa UniformScaling &&
1260-
# Jacobians don't wrap, so just ignore those cases
1261-
f.jac === nothing &&
1262-
((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any &&
1263-
RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) &&
1264-
one(t) === oneunit(t) &&
1265-
hasmethod(ArrayInterface.promote_eltype,
1266-
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
1267-
hasmethod(promote_rule,
1268-
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
1269-
hasmethod(promote_rule,
1270-
Tuple{Type{eltype(u0)}, Type{typeof(t)}})) ||
1271-
(specialize === SciMLBase.FunctionWrapperSpecialize &&
1272-
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)))
1273-
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t)))
1274-
else
1275-
return f
1276-
end
1261+
f = if f isa ODEFunction && isinplace(f) && !(f.f isa AbstractSciMLOperator) &&
1262+
# Some reinitialization code still uses NLSolvers stuff which doesn't
1263+
# properly tag, so opt-out if potentially a mass matrix DAE
1264+
f.mass_matrix isa UniformScaling &&
1265+
# Jacobians don't wrap, so just ignore those cases
1266+
f.jac === nothing &&
1267+
((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any &&
1268+
RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) &&
1269+
one(t) === oneunit(t) &&
1270+
hasmethod(ArrayInterface.promote_eltype,
1271+
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
1272+
hasmethod(promote_rule,
1273+
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
1274+
hasmethod(promote_rule,
1275+
Tuple{Type{eltype(u0)}, Type{typeof(t)}})) ||
1276+
(specialize === SciMLBase.FunctionWrapperSpecialize &&
1277+
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)))
1278+
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t)))
1279+
else
1280+
return f
1281+
end
1282+
end
1283+
1284+
function promote_f(f::NonlinearFunction, ::Val{specialize}, u0, p) where {specialize}
1285+
# Ensure our jacobian will be of the same type as u0
1286+
uElType = u0 === nothing ? Float64 : eltype(u0)
1287+
if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray
1288+
f = @set f.jac_prototype = similar(f.jac_prototype, uElType)
1289+
end
1290+
1291+
f = if isinplace(f) && !(f.f isa AbstractSciMLOperator) &&
1292+
f.jac === nothing &&
1293+
((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any &&
1294+
RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) &&
1295+
hasmethod(ArrayInterface.promote_eltype,
1296+
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
1297+
hasmethod(promote_rule,
1298+
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}})) ||
1299+
(specialize === SciMLBase.FunctionWrapperSpecialize &&
1300+
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)))
1301+
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p)))
12771302
else
12781303
return f
12791304
end

0 commit comments

Comments
 (0)