Skip to content

Commit b0dc015

Browse files
Merge pull request #891 from AayushSabharwal/as/fix-mtk-tests
feat: add new `remake(::AbstractSciMLFunction)`, fix some `remake` bugs.
2 parents 97a79f7 + d455a5a commit b0dc015

File tree

5 files changed

+142
-115
lines changed

5 files changed

+142
-115
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import FunctionWrappersWrappers
2222
import RuntimeGeneratedFunctions
2323
import EnumX
2424
import ADTypes: ADTypes, AbstractADType
25-
import Accessors: @set, @reset, @delete
25+
import Accessors: @set, @reset, @delete, @insert
2626
using Expronicon.ADT: @match
2727

2828
using Reexport

src/initialization.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,13 @@ function evaluate_f(
124124
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
125125
end
126126

127-
function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
127+
function evaluate_f(
128+
integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
128129
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
129130
end
130131

131-
function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t)
132+
function evaluate_f(integrator::AbstractSDDEIntegrator,
133+
prob::AbstractSDDEProblem, f, isinplace, u, p, t)
132134
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
133135
end
134136

src/remake.jl

Lines changed: 112 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,105 @@ function remake(
9999
_remake_internal(prob; kwargs..., p)
100100
end
101101

102+
"""
103+
$(TYPEDSIGNATURES)
104+
105+
A utility function which merges two `NamedTuple`s `a` and `b`, assuming that the
106+
keys of `a` are a subset of those of `b`. Values in `b` take priority over those
107+
in `a`, except if they are `nothing`. Keys not present in `a` are assumed to have
108+
a value of `nothing`.
109+
"""
110+
function _similar_namedtuple_merge_ignore_nothing(a::NamedTuple, b::NamedTuple)
111+
ks = fieldnames(typeof(b))
112+
return NamedTuple{ks}(ntuple(Val(length(ks))) do i
113+
something(get(b, ks[i], nothing), get(a, ks[i], nothing), Some(nothing))
114+
end)
115+
end
116+
117+
"""
118+
remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
119+
120+
`remake` the given `func`. Return an `AbstractSciMLFunction` of the same kind, `isinplace` and
121+
`specialization` as `func`. Retain the properties of `func`, except those that are overridden
122+
by keyword arguments. For stochastic functions (e.g. `SDEFunction`) the `g` keyword argument
123+
is used to override `func.g`. For split functions (e.g. `SplitFunction`) the `f2` keyword
124+
argument is used to override `func.f2`, and `f` is used for `func.f1`. If
125+
`f isa AbstractSciMLFunction` and `func` is not a split function, properties of `f` will
126+
override those of `func` (but not ones provided via keyword arguments). Properties of `f` that
127+
are `nothing` will fall back to those in `func` (unless provided via keyword arguments). If
128+
`f` is a different type of `AbstractSciMLFunction` from `func`, the returned function will be
129+
of the kind of `f` unless `func` is a split function. If `func` is a split function, `f` and
130+
`f2` will be wrapped in the appropriate `AbstractSciMLFunction` type with the same `isinplace`
131+
and `specialization` as `func`.
132+
"""
133+
function remake(
134+
func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
135+
# retain iip and spec of original function
136+
iip = isinplace(func)
137+
spec = specialization(func)
138+
# retain properties of original function
139+
props = getproperties(func)
140+
141+
if f === missing || is_split_function(func)
142+
# if no `f` is provided, create the same type of SciMLFunction
143+
T = parameterless_type(func)
144+
f = isdefined(func, :f) ? func.f : func.f1
145+
elseif f isa AbstractSciMLFunction
146+
# if `f` is a SciMLFunction, create that type
147+
T = parameterless_type(f)
148+
# properties of `f` take priority over those in the existing `func`
149+
# ignore properties of `f` which are `nothing` but present in `func`
150+
props = _similar_namedtuple_merge_ignore_nothing(props, getproperties(f))
151+
f = isdefined(f, :f) ? f.f : f.f1
152+
else
153+
# if `f` is provided but not a SciMLFunction, create the same type
154+
T = parameterless_type(func)
155+
end
156+
157+
# minor hack to avoid breaking MTK, since prior to ~9.57 in `remake_initialization_data`
158+
# it creates a `NonlinearFunction` inside a `NonlinearFunction`. Just recursively unwrap
159+
# in this case and forget about properties.
160+
while !is_split_function(T) && f isa AbstractSciMLFunction
161+
f = isdefined(f, :f) ? f.f : f.f1
162+
end
163+
164+
props = @delete props.f
165+
props = @delete props.f1
166+
167+
args = (f,)
168+
if is_split_function(T)
169+
# for DynamicalSDEFunction and SplitFunction
170+
if isdefined(props, :cache)
171+
props = @insert props._func_cache = props.cache
172+
props = @delete props.cache
173+
end
174+
175+
# `f1` and `f2` are wrapped in another SciMLFunction, unless they're
176+
# already wrapped in the appropriate type or are an `AbstractSciMLOperator`
177+
if !(f isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)})
178+
f = split_function_f_wrapper(T){iip, spec}(f)
179+
end
180+
# For SplitFunction
181+
# we don't do the same thing as `g`, because for SDEs `g` is
182+
# stored in the problem as well, whereas for Split ODEs etc
183+
# f2 is a part of the function. Thus, if the user provides
184+
# a SciMLFunction for `f` which contains `f2` we use that.
185+
f2 = coalesce(f2, get(props, :f2, missing), func.f2)
186+
if !(f2 isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)})
187+
f2 = split_function_f_wrapper(T){iip, spec}(f2)
188+
end
189+
props = @delete props.f2
190+
args = (args..., f2)
191+
end
192+
if isdefined(func, :g)
193+
# For SDEs/SDDEs where `g` is not a keyword
194+
g = coalesce(g, func.g)
195+
props = @delete props.g
196+
args = (args..., g)
197+
end
198+
T{iip, spec}(args...; props..., kwargs...)
199+
end
200+
102201
"""
103202
remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing,
104203
p = missing, kwargs = missing, _kwargs...)
@@ -135,53 +234,26 @@ function remake(prob::ODEProblem; f = missing,
135234
initialization_data = nothing
136235
end
137236

138-
if f === missing
139-
if specialization(prob.f) === FunctionWrapperSpecialize
140-
ptspan = promote_tspan(tspan)
141-
if iip
142-
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
143-
wrapfun_iip(
144-
unwrapped_f(prob.f.f),
145-
(newu0, newu0, newp,
146-
ptspan[1])); initialization_data)
147-
else
148-
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
149-
wrapfun_oop(
150-
unwrapped_f(prob.f.f),
151-
(newu0, newp,
152-
ptspan[1])); initialization_data)
153-
end
154-
else
155-
_f = prob.f
156-
if __has_initialization_data(_f)
157-
props = getproperties(_f)
158-
@reset props.initialization_data = initialization_data
159-
props = values(props)
160-
_f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...)
161-
end
162-
end
163-
elseif f isa AbstractODEFunction
164-
_f = f
165-
elseif specialization(prob.f) === FunctionWrapperSpecialize
237+
f = coalesce(f, prob.f)
238+
f = remake(prob.f; f, initialization_data)
239+
240+
if specialization(f) === FunctionWrapperSpecialize
166241
ptspan = promote_tspan(tspan)
167242
if iip
168-
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
169-
(newu0, newu0, newp,
170-
ptspan[1])))
243+
f = remake(
244+
f; f = wrapfun_iip(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1])))
171245
else
172-
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
173-
(newu0, newp, ptspan[1])))
246+
f = remake(
247+
f; f = wrapfun_oop(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1])))
174248
end
175-
else
176-
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
177249
end
178250

179251
prob = if kwargs === missing
180-
ODEProblem{isinplace(prob)}(
181-
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
252+
ODEProblem{iip}(
253+
f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
182254
_kwargs...)
183255
else
184-
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
256+
ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...)
185257
end
186258

187259
if lazy_initialization === nothing
@@ -395,42 +467,6 @@ function remake(prob::SDEProblem;
395467
return prob
396468
end
397469

398-
"""
399-
remake(func::SDEFunction; f = missing, g = missing,
400-
mass_matrix = missing, analytic = missing, kwargs...)
401-
402-
Remake the given `SDEFunction`.
403-
"""
404-
function remake(func::Union{SDEFunction, SDDEFunction};
405-
f = missing,
406-
g = missing,
407-
mass_matrix = missing,
408-
analytic = missing,
409-
sys = missing,
410-
kwargs...)
411-
props = getproperties(func)
412-
props = @delete props.f
413-
props = @delete props.g
414-
@reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix)
415-
@reset props.analytic = coalesce(analytic, func.analytic)
416-
@reset props.sys = coalesce(sys, func.sys)
417-
418-
if f === missing
419-
f = func.f
420-
end
421-
422-
if g === missing
423-
g = func.g
424-
end
425-
426-
if f isa AbstractSciMLFunction
427-
f = f.f
428-
end
429-
430-
T = func isa SDEFunction ? SDEFunction : SDDEFunction
431-
return T{isinplace(func)}(f, g; props..., kwargs...)
432-
end
433-
434470
function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
435471
tspan = missing, p = missing, constant_lags = missing,
436472
dependent_lags = missing, order_discontinuity_t0 = missing,
@@ -497,28 +533,6 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
497533
return prob
498534
end
499535

500-
function remake(func::DDEFunction;
501-
f = missing,
502-
mass_matrix = missing,
503-
analytic = missing,
504-
sys = missing,
505-
kwargs...)
506-
props = getproperties(func)
507-
props = @delete props.f
508-
@reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix)
509-
@reset props.analytic = coalesce(analytic, func.analytic)
510-
@reset props.sys = coalesce(sys, func.sys)
511-
512-
if f === missing
513-
f = func.f
514-
end
515-
if f isa AbstractSciMLFunction
516-
f = f.f
517-
end
518-
519-
return DDEFunction{isinplace(func)}(f; props..., kwargs...)
520-
end
521-
522536
function remake(prob::SDDEProblem;
523537
f = missing,
524538
g = missing,
@@ -706,6 +720,7 @@ function remake(prob::NonlinearProblem;
706720
initialization_data = nothing
707721
end
708722

723+
f = coalesce(f, prob.f)
709724
f = remake(prob.f; f, initialization_data)
710725

711726
if problem_type === missing
@@ -737,22 +752,6 @@ function remake(prob::NonlinearProblem;
737752
return prob
738753
end
739754

740-
function remake(func::NonlinearFunction;
741-
f = missing,
742-
kwargs...)
743-
props = getproperties(func)
744-
props = @delete props.f
745-
746-
if f === missing
747-
f = func.f
748-
end
749-
if f isa AbstractSciMLFunction
750-
f = f.f
751-
end
752-
753-
return NonlinearFunction{isinplace(func)}(f; props..., kwargs...)
754-
end
755-
756755
"""
757756
remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
758757
kwargs = missing, _kwargs...)
@@ -775,6 +774,7 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
775774
initialization_data = nothing
776775
end
777776

777+
f = coalesce(f, prob.f)
778778
f = remake(prob.f; f, initialization_data)
779779

780780
prob = if kwargs === missing

src/scimlfunctions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4601,6 +4601,20 @@ has_Wfact_t(f::JacobianWrapper) = has_Wfact_t(f.f)
46014601
has_paramjac(f::JacobianWrapper) = has_paramjac(f.f)
46024602
has_colorvec(f::JacobianWrapper) = has_colorvec(f.f)
46034603

4604+
is_split_function(x) = is_split_function(typeof(x))
4605+
is_split_function(::Type) = false
4606+
function is_split_function(::Type{T}) where {T <: Union{
4607+
SplitFunction, SplitSDEFunction, DynamicalODEFunction,
4608+
DynamicalDDEFunction, DynamicalSDEFunction}}
4609+
true
4610+
end
4611+
4612+
split_function_f_wrapper(::Type{<:SplitFunction}) = ODEFunction
4613+
split_function_f_wrapper(::Type{<:SplitSDEFunction}) = SDEFunction
4614+
split_function_f_wrapper(::Type{<:DynamicalODEFunction}) = ODEFunction
4615+
split_function_f_wrapper(::Type{<:DynamicalDDEFunction}) = DDEFunction
4616+
split_function_f_wrapper(::Type{<:DynamicalSDEFunction}) = DDEFunction
4617+
46044618
######### Additional traits
46054619

46064620
islinear(::AbstractDiffEqFunction) = false

test/remake_tests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,14 @@ end
372372
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
373373
@test_nowarn remake(prob; u0 = [:x => nothing], p = [:a => nothing])
374374
end
375+
376+
@testset "retain properties of `SciMLFunction` passed to `remake`" begin
377+
u0 = [1.0; 2.0; 3.0]
378+
p = [10.0, 20.0, 30.0]
379+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
380+
fn = NonlinearFunction(nllorenz!; sys, resid_prototype = zeros(Float64, 3))
381+
prob = NonlinearProblem(fn, u0, p)
382+
fn2 = NonlinearFunction(nllorenz!; resid_prototype = zeros(Float32, 3))
383+
prob2 = remake(prob; f = fn2)
384+
@test prob2.f.resid_prototype isa Vector{Float32}
385+
end

0 commit comments

Comments
 (0)