From de94f4e7d27cba8fc6fd1ddfedf0d5d0853e830f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Oct 2024 14:29:55 +0530 Subject: [PATCH 1/7] feat: add `OverrideInitData` --- src/SciMLBase.jl | 1 + src/initialization.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 src/initialization.jl diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index c8c8d0b66..ea78cc02a 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -654,6 +654,7 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context. struct TrackerOriginator <: ADOriginator end include("utils.jl") +include("initialization.jl") include("function_wrappers.jl") include("scimlfunctions.jl") include("alg_traits.jl") diff --git a/src/initialization.jl b/src/initialization.jl new file mode 100644 index 000000000..6cbbcda72 --- /dev/null +++ b/src/initialization.jl @@ -0,0 +1,26 @@ +""" + $(TYPEDEF) + +A collection of all the data required for `OverrideInit`. +""" +struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} + """ + The `AbstractNonlinearProblem` to solve for initialization. + """ + initializeprob::IProb + """ + A function which takes `(initializeprob, prob)` and updates + the parameters of the former with their values in the latter. + """ + update_initializeprob!::UIProb + """ + A function which takes the solution of `initializeprob` and returns + the state vector of the original problem. + """ + initializeprobmap::IProbMap + """ + A function which takes the solution of `initializeprob` and returns + the parameter object of the original problem. + """ + initializeprobpmap::IProbPmap +end From 19d46e67e09c69c51899bb6daca62c2376577a36 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Oct 2024 14:30:29 +0530 Subject: [PATCH 2/7] refactor: use `initialization_data` instead of `initializeprob`, etc. --- src/scimlfunctions.jl | 162 ++++++++++++++++++++++-------------------- 1 file changed, 85 insertions(+), 77 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index b05ffbfd5..6058d7e0d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -405,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, - O, TCV, SYS, - IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip} + O, TCV, + SYS, ID, NLP} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -423,10 +423,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW observed::O colorvec::TCV sys::SYS - initializeprob::IProb - update_initializeprob!::UIProb - initializeprobmap::IProbMap - initializeprobpmap::IProbPmap + initialization_data::ID nlprob::NLP end @@ -530,8 +527,8 @@ information on generating the SplitFunction from this symbolic engine. """ struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt, - TPJ, O, TCV, SYS, - IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip} + TPJ, O, + TCV, SYS, ID, NLP} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -550,10 +547,7 @@ struct SplitFunction{ observed::O colorvec::TCV sys::SYS - initializeprob::IProb - update_initializeprob!::UIProb - initializeprobmap::IProbMap - initializeprobpmap::IProbPmap + initialization_data::ID nlprob::NLP end @@ -1529,7 +1523,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, UIProb, IProbMap, IProbPmap} <: + SYS, ID} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1545,10 +1539,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP observed::O colorvec::TCV sys::SYS - initializeprob::IProb - update_initializeprob!::UIProb - initializeprobmap::IProbMap - initializeprobpmap::IProbPmap + initialization_data::ID end """ @@ -2440,6 +2431,8 @@ function ODEFunction{iip, specialize}(f; initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, nlprob = __has_nlprob(f) ? f.nlprob : nothing, + initialization_data = __has_initialization_data(f) ? f.initialization_data : + nothing ) where {iip, specialize } @@ -2486,8 +2479,11 @@ function ODEFunction{iip, specialize}(f; _f = prepare_function(f) sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) - @assert typeof(initializeprob) <: + @assert typeof(initdata.initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} if specialize === NoSpecialize @@ -2497,11 +2493,10 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap, nlprob) + observed, _colorvec, sys, initdata, nlprob) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2510,16 +2505,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), - typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap), - typeof(nlprob)}(_f, mass_matrix, + typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap, nlprob) + observed, _colorvec, sys, initdata, nlprob) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2528,14 +2518,10 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap), - typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac, - jvp, vjp, jac_prototype, sparsity, Wfact, + typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, + jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap, nlprob) + observed, _colorvec, sys, initdata, nlprob) end end @@ -2552,13 +2538,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any, Any, Any, Any}( + typeof(f.sys), Any, Any}( newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, - f.update_initializeprob!, f.initializeprobmap, - f.initializeprobpmap, f.nlprob) + f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2566,14 +2550,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!), - typeof(f.initializeprobmap), - typeof(f.initializeprobpmap), - typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}( + newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!, - f.initializeprobmap, f.initializeprobpmap, f.nlprob) + f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob) end end @@ -2704,8 +2685,8 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap, nlprob) + observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing, + initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -2714,17 +2695,20 @@ end throw(NonconformingFunctionsError(["f2"])) end + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) + SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap), - typeof(initializeprobpmap), typeof(nlprob)}( + typeof(sys), typeof(initdata), typeof(nlprob)}( f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob) + initdata, nlprob) end function SplitFunction{iip, specialize}(f1, f2; mass_matrix = __has_mass_matrix(f1) ? @@ -2761,23 +2745,27 @@ function SplitFunction{iip, specialize}(f1, f2; f1.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing, - nlprob = __has_nlprob(f1) ? f1.nlprob : nothing + nlprob = __has_nlprob(f1) ? f1.nlprob : nothing, + initialization_data = __has_initialization_data(f1) ? f1.initialization_data : + nothing ) where {iip, specialize } sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) - @assert typeof(initializeprob) <: + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) + @assert typeof(initdata.initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, + Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap, - initializeprobpmap, initializeprobpmap, nlprob) + observed, colorvec, sys, initdata, nlprob) else SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(_func_cache), typeof(analytic), @@ -2785,13 +2773,11 @@ function SplitFunction{iip, specialize}(f1, f2; typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap), typeof(nlprob)}(f1, f2, + typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob) + initdata, nlprob) end end @@ -3420,7 +3406,9 @@ function DAEFunction{iip, specialize}(f; update_initializeprob! = __has_update_initializeprob!(f) ? f.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, - initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where { + initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, + initialization_data = __has_initialization_data(f) ? f.initialization_data : + nothing) where { iip, specialize } @@ -3452,33 +3440,32 @@ function DAEFunction{iip, specialize}(f; _f = prepare_function(f) sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) - @assert typeof(initializeprob) <: + @assert typeof(initdata.initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap) + _colorvec, sys, initdata) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap)}( + typeof(sys), typeof(initdata)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap) + _colorvec, sys, initdata) end end @@ -4397,6 +4384,14 @@ function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing) return sys end +function reconstruct_initialization_data( + initdata, initprob, update_initprob!, initprobmap, initprobpmap) + if initdata === nothing && initprob !== nothing + initdata = OverrideInitData(initprob, update_initprob!, initprobmap, initprobpmap) + end + return initdata +end + ########## Existence Functions # Check that field/property exists (may be nothing) @@ -4420,11 +4415,20 @@ __has_colorvec(f) = isdefined(f, :colorvec) __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) -__has_initializeprob(f) = isdefined(f, :initializeprob) -__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!) -__has_initializeprobmap(f) = isdefined(f, :initializeprobmap) -__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap) __has_nlprob(f) = isdefined(f, :nlprob) +function __has_initializeprob(f) + has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob) +end +function __has_update_initializeprob!(f) + has_initialization_data(f) && isdefined(f.initialization_data, :update_initializeprob!) +end +function __has_initializeprobmap(f) + has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobmap) +end +function __has_initializeprobpmap(f) + has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobpmap) +end +__has_initialization_data(f) = isdefined(f, :initialization_data) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -4438,16 +4442,20 @@ has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothin has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing function has_initializeprob(f::AbstractSciMLFunction) - __has_initializeprob(f) && f.initializeprob !== nothing + __has_initializeprob(f) && f.initialization_data.initializeprob !== nothing end function has_update_initializeprob!(f::AbstractSciMLFunction) - __has_update_initializeprob!(f) && f.update_initializeprob! !== nothing + __has_update_initializeprob!(f) && + f.initialization_data.update_initializeprob! !== nothing end function has_initializeprobmap(f::AbstractSciMLFunction) - __has_initializeprobmap(f) && f.initializeprobmap !== nothing + __has_initializeprobmap(f) && f.initialization_data.initializeprobmap !== nothing end function has_initializeprobpmap(f::AbstractSciMLFunction) - __has_initializeprobpmap(f) && f.initializeprobpmap !== nothing + __has_initializeprobpmap(f) && f.initialization_data.initializeprobpmap !== nothing +end +function has_initialization_data(f::AbstractSciMLFunction) + __has_initialization_data(f) && f.initialization_data !== nothing end function has_syms(f::AbstractSciMLFunction) From 6ee8151668a3d107716cb84628a13c0ccb5c3957 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Oct 2024 15:17:28 +0530 Subject: [PATCH 3/7] fix: add `getproperty` method for backwards compatibility --- src/scimlfunctions.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 6058d7e0d..eccce0deb 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4606,3 +4606,14 @@ function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym::Symb end SymbolicIndexingInterface.constant_structure(::AbstractSciMLFunction) = true + +function Base.getproperty(x::Union{ODEFunction, SplitFunction, DAEFunction}, sym::Symbol) + if sym == :initializeprob || sym == :update_initializeprob! || sym == :initializeprobmap || sym == :initializeprobpmap + if x.initialization_data === nothing + return nothing + else + return getproperty(x.initialization_data, sym) + end + end + return getfield(x, sym) +end From 03aa8a2b63edd4a7f56bf6bcde4dd6841e24120a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Oct 2024 16:21:05 +0530 Subject: [PATCH 4/7] fix: reorder `SplitFunction` fields to avoid breaking change The existing syntax is `initializeprob, ..., nlprob`. Trying to add `initialization_data` in the middle breaks the non-kwarg-only method. Putting it at the end fixes this issue. Anything old still has the order it relies on, with `initialization_data` defaulting to `nothing`, and anything new would just have to provide the redundant kwargs if it needs to specify `initialization_data`. --- src/scimlfunctions.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index eccce0deb..027bdacc1 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -547,8 +547,8 @@ struct SplitFunction{ observed::O colorvec::TCV sys::SYS - initialization_data::ID nlprob::NLP + initialization_data::ID end @doc doc""" @@ -2518,7 +2518,8 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, + typeof(sys), typeof(initdata), typeof(nlprob)}( + _f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, observed, _colorvec, sys, initdata, nlprob) @@ -2686,7 +2687,7 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing, - initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob) + initializeprobmap = nothing, initializeprobpmap = nothing, nlprob = nothing, initialization_data = nothing) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -4608,12 +4609,13 @@ end SymbolicIndexingInterface.constant_structure(::AbstractSciMLFunction) = true function Base.getproperty(x::Union{ODEFunction, SplitFunction, DAEFunction}, sym::Symbol) - if sym == :initializeprob || sym == :update_initializeprob! || sym == :initializeprobmap || sym == :initializeprobpmap - if x.initialization_data === nothing - return nothing - else - return getproperty(x.initialization_data, sym) + if sym == :initializeprob || sym == :update_initializeprob! || + sym == :initializeprobmap || sym == :initializeprobpmap + if x.initialization_data === nothing + return nothing + else + return getproperty(x.initialization_data, sym) + end end - end - return getfield(x, sym) + return getfield(x, sym) end From 0df5804ed042675231775a1c50cc49cbaeb0cabf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Oct 2024 17:06:06 +0530 Subject: [PATCH 5/7] fix: move `initializeprob` typeasserts to `OverrideInitData` --- src/initialization.jl | 6 ++++++ src/scimlfunctions.jl | 10 +--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 6cbbcda72..9f7567b98 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -23,4 +23,10 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} the parameter object of the original problem. """ initializeprobpmap::IProbPmap + + function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K, + initprobpmap::L) where {I, J, K, L} + @assert initprob isa Union{NonlinearProblem, NonlinearLeastSquaresProblem} + return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap) + end end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 027bdacc1..e64519fda 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2483,9 +2483,6 @@ function ODEFunction{iip, specialize}(f; initialization_data, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) - @assert typeof(initdata.initializeprob) <: - Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} - if specialize === NoSpecialize ODEFunction{iip, specialize, Any, Any, Any, Any, @@ -2756,8 +2753,6 @@ function SplitFunction{iip, specialize}(f1, f2; initdata = reconstruct_initialization_data( initialization_data, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) - @assert typeof(initdata.initializeprob) <: - Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, @@ -3445,9 +3440,6 @@ function DAEFunction{iip, specialize}(f; initialization_data, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) - @assert typeof(initdata.initializeprob) <: - Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} - if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, @@ -4455,7 +4447,7 @@ end function has_initializeprobpmap(f::AbstractSciMLFunction) __has_initializeprobpmap(f) && f.initialization_data.initializeprobpmap !== nothing end -function has_initialization_data(f::AbstractSciMLFunction) +function has_initialization_data(f) __has_initialization_data(f) && f.initialization_data !== nothing end From 2804efb2a626da7b27cec39b5883d2db1c6bd043 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 4 Nov 2024 13:13:11 +0530 Subject: [PATCH 6/7] refactor: format --- test/downstream/ensemble_diffeq.jl | 20 ++++++++++---------- test/downstream/ensemble_multi_prob.jl | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/downstream/ensemble_diffeq.jl b/test/downstream/ensemble_diffeq.jl index 2985cc72b..2fed8ed77 100644 --- a/test/downstream/ensemble_diffeq.jl +++ b/test/downstream/ensemble_diffeq.jl @@ -2,18 +2,18 @@ using OrdinaryDiffEq, Test A = [1 2 3 4] -prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0)) +prob = ODEProblem((u, p, t) -> A * u, ones(2, 2), (0.0, 1.0)) function prob_func(prob, i, repeat) remake(prob, u0 = i * prob.u0) end ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat = 0.01) @test sim isa EnsembleSolution -@test size(sim[1,:,:,:]) == (2,101,10) -@test size(sim[:,1,:,:]) == (2,101,10) -@test size(sim[:,:,1,:]) == (2,2,10) -@test size(sim[:,:,:,1]) == (2,2,101) -@test Array(sim)[1,:,:,:] == sim[1,:,:,:] -@test Array(sim)[:,1,:,:] == sim[:,1,:,:] -@test Array(sim)[:,:,1,:] == sim[:,:,1,:] -@test Array(sim)[:,:,:,1] == sim[:,:,:,1] \ No newline at end of file +@test size(sim[1, :, :, :]) == (2, 101, 10) +@test size(sim[:, 1, :, :]) == (2, 101, 10) +@test size(sim[:, :, 1, :]) == (2, 2, 10) +@test size(sim[:, :, :, 1]) == (2, 2, 101) +@test Array(sim)[1, :, :, :] == sim[1, :, :, :] +@test Array(sim)[:, 1, :, :] == sim[:, 1, :, :] +@test Array(sim)[:, :, 1, :] == sim[:, :, 1, :] +@test Array(sim)[:, :, :, 1] == sim[:, :, :, 1] diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index 92f20065b..5180f8cf6 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -17,9 +17,9 @@ prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) for i in 1:3 - @test sol[1,:,i] == sol.u[i][x] - @test sol[2,:,i] == sol.u[i][y] + @test sol[1, :, i] == sol.u[i][x] + @test sol[2, :, i] == sol.u[i][y] end # Ensemble is a recursive array @test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] -@test only.(sol(1.0, idxs = [x])) ≈ [sol[i][1, end] for i in 1:3] \ No newline at end of file +@test only.(sol(1.0, idxs = [x])) ≈ [sol[i][1, end] for i in 1:3] From 55c171dc4ecfa5ad13df23d67e10ed59de6cd85e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 4 Nov 2024 16:19:00 +0530 Subject: [PATCH 7/7] fix: fix `remake`'s usage of `initializeprob` --- src/remake.jl | 56 +++++++++++++++++++++++---------------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index e4540cd68..0ffd91003 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -125,10 +125,10 @@ function remake(prob::ODEProblem; f = missing, if f === missing if build_initializeprob - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob( + initialization_data = remake_initialization_data( prob.f.sys, prob.f, u0, tspan[1], p) else - initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing + initialization_data = nothing end if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) @@ -137,45 +137,21 @@ function remake(prob::ODEProblem; f = missing, wrapfun_iip( unwrapped_f(prob.f.f), (newu0, newu0, newp, - ptspan[1])); - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) + ptspan[1])); initialization_data) else _f = ODEFunction{iip, FunctionWrapperSpecialize}( wrapfun_oop( unwrapped_f(prob.f.f), (newu0, newp, - ptspan[1])); - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) + ptspan[1])); initialization_data) end else _f = prob.f - if __has_initializeprob(_f) + if __has_initialization_data(_f) props = getproperties(_f) - @reset props.initializeprob = initializeprob + @reset props.initialization_data = initialization_data props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) - end - if __has_update_initializeprob!(_f) - props = getproperties(_f) - @reset props.update_initializeprob! = update_initializeprob! - props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) - end - if __has_initializeprobmap(_f) - props = getproperties(_f) - @reset props.initializeprobmap = initializeprobmap - props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) - end - if __has_initializeprobpmap(_f) - props = getproperties(_f) - @reset props.initializeprobpmap = initializeprobpmap - props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) + _f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...) end end elseif f isa AbstractODEFunction @@ -206,6 +182,9 @@ end """ remake_initializeprob(sys, scimlfn, u0, t0, p) +!! WARN +This method is deprecated. Please see `remake_initialization_data` + Re-create the initialization problem present in the function `scimlfn`, using the associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and `p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an @@ -223,6 +202,21 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p) scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap end +""" + remake_initialization_data(sys, scimlfn, u0, t0, p) + +Re-create the initialization data present in the function `scimlfn`, using the +associated system `sys` and the user provided new values of `u0`, initial time `t0` and +`p`. By default, this calls `remake_initializeprob` for backward compatibility and +attempts to construct an `OverrideInitData` from the result. + +Note that `u0` or `p` may be `missing` if the user does not provide a value for them. +""" +function remake_initialization_data(sys, scimlfn, u0, t0, p) + return reconstruct_initialization_data( + nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...) +end + """ remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing, _kwargs...)