From 87235906c600979028dfd09fbe656d103116ee52 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 6 Aug 2024 12:47:23 +0530 Subject: [PATCH 1/6] feat: support `initializeprobpmap` in relevant `SciMLFunction`s --- src/scimlfunctions.jl | 77 +++++++++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 50370c36d..1128d15a3 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, IProb, IProbMap} <: AbstractODEFunction{iip} + SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -421,6 +421,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW sys::SYS initializeprob::IProb initializeprobmap::IProbMap + initializeprobpmap::IProbPmap end @doc doc""" @@ -518,7 +519,7 @@ information on generating the SplitFunction from this symbolic engine. struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, - TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip} + TCV, SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -538,6 +539,7 @@ struct SplitFunction{ sys::SYS initializeprob::IProb initializeprobmap::IProbMap + initializeprobpmap::IProbPmap end @doc doc""" @@ -1506,7 +1508,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, IProbMap} <: + SYS, IProb, IProbMap, IProbPmap} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1524,6 +1526,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP sys::SYS initializeprob::IProb initializeprobmap::IProbMap + initializeprobpmap::IProbPmap end """ @@ -2410,7 +2413,8 @@ function ODEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing ) where {iip, specialize } @@ -2468,10 +2472,11 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + initializeprobpmap) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2481,10 +2486,12 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix, + analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + initializeprobpmap) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2493,11 +2500,12 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + initializeprobpmap) end end @@ -2514,10 +2522,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, + f.initializeprobpmap) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2525,12 +2534,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initializeprob), - typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap), + typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, f.observed, f.colorvec, f.sys, f.initializeprob, - f.initializeprobmap) + f.initializeprobmap, f.initializeprobpmap) end end @@ -2632,7 +2641,7 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob, initializeprobmap) + observed, colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -2646,11 +2655,12 @@ end typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}( + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(initializeprobpmap)}( f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, initializeprobmap) + initializeprob, initializeprobmap, initializeprobpmap) end function SplitFunction{iip, specialize}(f1, f2; mass_matrix = __has_mass_matrix(f1) ? @@ -2680,7 +2690,8 @@ function SplitFunction{iip, specialize}(f1, f2; nothing, sys = __has_sys(f1) ? f1.sys : nothing, initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing + initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing, + initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing ) where {iip, specialize } @@ -2691,11 +2702,12 @@ function SplitFunction{iip, specialize}(f1, f2; if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, + Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob, initializeprobmap) + observed, colorvec, sys, initializeprob, initializeprobmap, + initializeprobpmap, initializeprobpmap) else SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(_func_cache), typeof(analytic), @@ -2703,11 +2715,12 @@ function SplitFunction{iip, specialize}(f1, f2; typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2, + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(initializeprobpmap)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, initializeprobmap) + initializeprob, initializeprobmap, initializeprobpmap) end end @@ -3333,7 +3346,8 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where { + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where { iip, specialize } @@ -3373,21 +3387,22 @@ function DAEFunction{iip, specialize}(f; DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap) + _colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}( + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(initializeprobpmap)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap) + _colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap) end end @@ -4331,6 +4346,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) __has_initializeprobmap(f) = isdefined(f, :initializeprobmap) +__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -4349,6 +4365,9 @@ end function has_initializeprobmap(f::AbstractSciMLFunction) __has_initializeprobmap(f) && f.initializeprobmap !== nothing end +function has_initializeprobpmap(f::AbstractSciMLFunction) + __has_initializeprobpmap(f) && f.initializeprobpmap !== nothing +end function has_syms(f::AbstractSciMLFunction) if __has_syms(f) From 02d4a089f7a328acc9b6d7e27ec8fd51333f4eeb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 6 Aug 2024 15:32:58 +0530 Subject: [PATCH 2/6] feat: implement `constructorof` for `ODEProblem` --- src/problems/ode_problems.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index bb228e84d..bc9be9400 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -174,6 +174,17 @@ function Base.setproperty!(prob::ODEProblem, s::Symbol, v, order::Symbol) Base.setfield!(prob, s, v, order) end +function ConstructionBase.constructorof(::Type{P}) where {P <: ODEProblem} + function ctor(f, u0, tspan, p, kw, pt) + if f isa AbstractODEFunction + iip = isinplace(f) + else + iip = isinplace(f, 4) + end + return ODEProblem{iip}(f, u0, tspan, p, pt; kw...) + end +end + """ ODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet()) From 991e559f95e55d0fb405a8bfece7edcdc45a22c6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 Aug 2024 14:16:53 +0530 Subject: [PATCH 3/6] fix: use tspan in `updated_u0_p` for initial values dependent on time --- src/remake.jl | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index fcececb79..f4c112611 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -118,7 +118,7 @@ function remake(prob::ODEProblem; f = missing, tspan = prob.tspan end - newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) + newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) iip = isinplace(prob) @@ -214,7 +214,7 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) + u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) if problem_type === missing problem_type = prob.problem_type @@ -280,7 +280,7 @@ function remake(prob::SDEProblem; tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) + u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) if noise === missing noise = prob.noise @@ -496,35 +496,35 @@ anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() function _updated_u0_p_internal( - prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false) + prob, ::Missing, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false) return state_values(prob), parameter_values(prob) end function _updated_u0_p_internal( - prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false) + prob, ::Missing, p, t0; interpret_symbolicmap = true, use_defaults = false) u0 = state_values(prob) if p isa AbstractArray && isempty(p) return _updated_u0_p_internal( - prob, u0, parameter_values(prob); interpret_symbolicmap) + prob, u0, parameter_values(prob), t0; interpret_symbolicmap) end eltype(p) <: Pair && interpret_symbolicmap || return u0, p defs = default_values(prob) p = fill_p(prob, anydict(p); defs, use_defaults) - return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true)) + return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0) end function _updated_u0_p_internal( - prob, u0, ::Missing; interpret_symbolicmap = true, use_defaults = false) + prob, u0, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false) p = parameter_values(prob) eltype(u0) <: Pair || return u0, p defs = default_values(prob) u0 = fill_u0(prob, anydict(u0); defs, use_defaults) - return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false)) + return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0) end function _updated_u0_p_internal( - prob, u0, p; interpret_symbolicmap = true, use_defaults = false) + prob, u0, p, t0; interpret_symbolicmap = true, use_defaults = false) isu0symbolic = eltype(u0) <: Pair ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap @@ -538,7 +538,7 @@ function _updated_u0_p_internal( if ispsymbolic p = fill_p(prob, anydict(p); defs, use_defaults) end - return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic)) + return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic), t0) end function fill_u0(prob, u0; defs = nothing, use_defaults = false) @@ -629,7 +629,7 @@ function fill_p(prob, p; defs = nothing, use_defaults = false) return newvals end -function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}) +function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p @@ -642,13 +642,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}) # FIXME: need to provide `u` since the observed function expects it. # This is sort of an implicit dependency on MTK. The values of `u` won't actually be # used, since any state symbols in the expression were substituted out earlier. - temp_state = ProblemState(; u = state_values(prob), p = p) + temp_state = ProblemState(; u = state_values(prob), p = p, t = t0) u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) for (k, v) in u0) return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p end -function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}) +function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) @@ -661,13 +661,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}) # FIXME: need to provide `p` since the observed function expects an `MTKParameters` # this is sort of an implicit dependency on MTK. The values of `p` won't actually be # used, since any parameter symbols in the expression were substituted out earlier. - temp_state = ProblemState(; u = u0, p = parameter_values(prob)) + temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0) p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) for (k, v) in p) return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end -function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}) +function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) @@ -677,11 +677,11 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}) end if !isu0dep u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0)) - return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true)) + return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0) end if !ispdep p = remake_buffer(prob, parameter_values(prob), keys(p), values(p)) - return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false)) + return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0) end varmap = merge(u0, p) @@ -693,7 +693,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}) remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end -function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false) +function updated_u0_p(prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false) if u0 === missing && p === missing return state_values(prob), parameter_values(prob) end @@ -712,7 +712,7 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = return (u0 === missing ? state_values(prob) : u0), (p === missing ? parameter_values(prob) : p) end - return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap, use_defaults) + return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults) end # overloaded in MTK to intercept symbolic remake From bad69a21770b2eaa1ea7746fa333891100f23dbe Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 Aug 2024 14:17:19 +0530 Subject: [PATCH 4/6] refactor: change `remake_initializeprob` to also return `initializeprobpmap` --- src/remake.jl | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index f4c112611..9b9acc14f 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -123,9 +123,8 @@ function remake(prob::ODEProblem; f = missing, iip = isinplace(prob) if f === missing - initializeprob, initializeprobmap = remake_initializeprob( - prob.f.sys, prob.f, u0 === missing ? newu0 : u0, - tspan[1], p === missing ? newp : p) + initializeprob, initializeprobmap, initializeprobpmap = remake_initializeprob( + prob.f.sys, prob.f, u0, tspan[1], p) if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) if iip @@ -134,14 +133,14 @@ function remake(prob::ODEProblem; f = missing, unwrapped_f(prob.f.f), (newu0, newu0, newp, ptspan[1])); - initializeprob, initializeprobmap) + initializeprob, initializeprobmap, initializeprobpmap) else _f = ODEFunction{iip, FunctionWrapperSpecialize}( wrapfun_oop( unwrapped_f(prob.f.f), (newu0, newp, ptspan[1])); - initializeprob, initializeprobmap) + initializeprob, initializeprobmap, initializeprobpmap) end else _f = prob.f @@ -159,6 +158,13 @@ function remake(prob::ODEProblem; f = missing, _f = parameterless_type(_f){ iip, specialization(_f), map(typeof, props)...}(props...) end + if __has_initializeprobpmap(_f) + props = getproperties(_f) + @reset props.initializeprobpmap = initializeprobpmap + props = values(props) + _f = parameterless_type(_f){ + iip, specialization(_f), map(typeof, props)...}(props...) + end end elseif f isa AbstractODEFunction _f = f @@ -189,15 +195,19 @@ end remake_initializeprob(sys, scimlfn, u0, t0, p) Re-create the initialization problem present in the function `scimlfn`, using the -associated system `sys`, and the new values of `u0`, initial time `t0` and `p`. By -default, returns `nothing, nothing` if `scimlfn` does not have an initialization -problem, and `scimlfn.initializeprob, scimlfn.initializeprobmap` if it does. +associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and +`p`. By default, returns `nothing, nothing, nothing` if `scimlfn` does not have an +initialization problem, and +`scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap` if it +does. + +Note that `u0` or `p` may be `missing` if the user does not provide a value for them. """ function remake_initializeprob(sys, scimlfn, u0, t0, p) if !has_initializeprob(scimlfn) - return nothing, nothing + return nothing, nothing, nothing end - return scimlfn.initializeprob, scimlfn.initializeprobmap + return scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap end """ From f69ad5a39f3e9bc406a6a1be52a701e000707159 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Sep 2024 15:18:10 +0530 Subject: [PATCH 5/6] feat: add `update_initializeprob!` to relevant SciMLFunctions, `remake` --- src/remake.jl | 27 +++++++++++------ src/scimlfunctions.jl | 70 ++++++++++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 9b9acc14f..69c28e3e2 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -123,7 +123,7 @@ function remake(prob::ODEProblem; f = missing, iip = isinplace(prob) if f === missing - initializeprob, initializeprobmap, initializeprobpmap = remake_initializeprob( + initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob( prob.f.sys, prob.f, u0, tspan[1], p) if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) @@ -133,14 +133,14 @@ function remake(prob::ODEProblem; f = missing, unwrapped_f(prob.f.f), (newu0, newu0, newp, ptspan[1])); - initializeprob, initializeprobmap, initializeprobpmap) + initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) else _f = ODEFunction{iip, FunctionWrapperSpecialize}( wrapfun_oop( unwrapped_f(prob.f.f), (newu0, newp, ptspan[1])); - initializeprob, initializeprobmap, initializeprobpmap) + initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) end else _f = prob.f @@ -151,6 +151,13 @@ function remake(prob::ODEProblem; f = missing, _f = parameterless_type(_f){ iip, specialization(_f), map(typeof, props)...}(props...) end + if __has_update_initializeprob!(_f) + props = getproperties(_f) + @reset props.update_initializeprob! = update_initializeprob! + props = values(props) + _f = parameterless_type(_f){ + iip, specialization(_f), map(typeof, props)...}(props...) + end if __has_initializeprobmap(_f) props = getproperties(_f) @reset props.initializeprobmap = initializeprobmap @@ -196,18 +203,19 @@ end Re-create the initialization problem present in the function `scimlfn`, using the associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and -`p`. By default, returns `nothing, nothing, nothing` if `scimlfn` does not have an +`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an initialization problem, and -`scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap` if it -does. +`scimlfn.initializeprob, scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap` +if it does. Note that `u0` or `p` may be `missing` if the user does not provide a value for them. """ function remake_initializeprob(sys, scimlfn, u0, t0, p) if !has_initializeprob(scimlfn) - return nothing, nothing, nothing + return nothing, nothing, nothing, nothing end - return scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap + return scimlfn.initializeprob, + scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap end """ @@ -703,7 +711,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end -function updated_u0_p(prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false) +function updated_u0_p( + prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false) if u0 === missing && p === missing return state_values(prob), parameter_values(prob) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 1128d15a3..0b41d59d9 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} + SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -420,6 +420,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW colorvec::TCV sys::SYS initializeprob::IProb + update_initializeprob!::UIProb initializeprobmap::IProbMap initializeprobpmap::IProbPmap end @@ -519,7 +520,7 @@ information on generating the SplitFunction from this symbolic engine. struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, - TCV, SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} + TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -538,6 +539,7 @@ struct SplitFunction{ colorvec::TCV sys::SYS initializeprob::IProb + update_initializeprob!::UIProb initializeprobmap::IProbMap initializeprobpmap::IProbPmap end @@ -1508,7 +1510,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, IProbMap, IProbPmap} <: + SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1525,6 +1527,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP colorvec::TCV sys::SYS initializeprob::IProb + update_initializeprob!::UIProb initializeprobmap::IProbMap initializeprobpmap::IProbPmap end @@ -2413,6 +2416,8 @@ function ODEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, + update_initializeprob! = __has_update_initializeprob!(f) ? + f.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing ) where {iip, @@ -2472,10 +2477,10 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap, + observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, @@ -2485,12 +2490,12 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), + typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap, + observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) else ODEFunction{iip, specialize, @@ -2500,11 +2505,12 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), + typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap, + observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) end end @@ -2522,10 +2528,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any, Any, Any}( + newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, + f.observed, f.colorvec, f.sys, f.initializeprob, + f.update_initializeprob!, f.initializeprobmap, f.initializeprobpmap) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), @@ -2534,11 +2542,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap), + typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!), + typeof(f.initializeprobmap), typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, + f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!, f.initializeprobmap, f.initializeprobpmap) end end @@ -2641,7 +2650,8 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap) + observed, colorvec, sys, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -2655,12 +2665,12 @@ end typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap), typeof(initializeprobpmap)}( f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, initializeprobmap, initializeprobpmap) + initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) end function SplitFunction{iip, specialize}(f1, f2; mass_matrix = __has_mass_matrix(f1) ? @@ -2690,6 +2700,8 @@ function SplitFunction{iip, specialize}(f1, f2; nothing, sys = __has_sys(f1) ? f1.sys : nothing, initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing, + update_initializeprob! = __has_update_initializeprob!(f1) ? + f1.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing ) where {iip, @@ -2701,12 +2713,12 @@ function SplitFunction{iip, specialize}(f1, f2; if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob, initializeprobmap, + observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap, initializeprobpmap, initializeprobpmap) else SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), @@ -2715,12 +2727,13 @@ function SplitFunction{iip, specialize}(f1, f2; typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), + typeof(initializeprobmap), typeof(initializeprobpmap)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, initializeprobmap, initializeprobpmap) + initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) end end @@ -3346,6 +3359,8 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, + update_initializeprob! = __has_update_initializeprob!(f) ? + f.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where { iip, @@ -3387,22 +3402,25 @@ function DAEFunction{iip, specialize}(f; DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap) + _colorvec, sys, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), + typeof(initializeprobmap), typeof(initializeprobpmap)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap) + _colorvec, sys, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) end end @@ -4345,6 +4363,7 @@ __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) +__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!) __has_initializeprobmap(f) = isdefined(f, :initializeprobmap) __has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap) @@ -4362,6 +4381,9 @@ has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing function has_initializeprob(f::AbstractSciMLFunction) __has_initializeprob(f) && f.initializeprob !== nothing end +function has_update_initializeprob!(f::AbstractSciMLFunction) + __has_update_initializeprob!(f) && f.update_initializeprob! !== nothing +end function has_initializeprobmap(f::AbstractSciMLFunction) __has_initializeprobmap(f) && f.initializeprobmap !== nothing end From a25ee6aae7160dd4b10b57fcf10f78ef02e94061 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 26 Sep 2024 15:49:13 +0530 Subject: [PATCH 6/6] feat: add `build_initializeprob = true` flag to `remake(::ODEProblem)` --- src/remake.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 69c28e3e2..6de145f8d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -112,6 +112,7 @@ function remake(prob::ODEProblem; f = missing, p = missing, kwargs = missing, interpret_symbolicmap = true, + build_initializeprob = true, use_defaults = false, _kwargs...) if tspan === missing @@ -123,8 +124,12 @@ function remake(prob::ODEProblem; f = missing, iip = isinplace(prob) if f === missing - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob( - prob.f.sys, prob.f, u0, tspan[1], p) + if build_initializeprob + initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob( + prob.f.sys, prob.f, u0, tspan[1], p) + else + initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing + end if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) if iip