diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index c8c8d0b66..ea78cc02a 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -654,6 +654,7 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context. struct TrackerOriginator <: ADOriginator end include("utils.jl") +include("initialization.jl") include("function_wrappers.jl") include("scimlfunctions.jl") include("alg_traits.jl") diff --git a/src/initialization.jl b/src/initialization.jl new file mode 100644 index 000000000..9f7567b98 --- /dev/null +++ b/src/initialization.jl @@ -0,0 +1,32 @@ +""" + $(TYPEDEF) + +A collection of all the data required for `OverrideInit`. +""" +struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} + """ + The `AbstractNonlinearProblem` to solve for initialization. + """ + initializeprob::IProb + """ + A function which takes `(initializeprob, prob)` and updates + the parameters of the former with their values in the latter. + """ + update_initializeprob!::UIProb + """ + A function which takes the solution of `initializeprob` and returns + the state vector of the original problem. + """ + initializeprobmap::IProbMap + """ + A function which takes the solution of `initializeprob` and returns + the parameter object of the original problem. + """ + initializeprobpmap::IProbPmap + + function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K, + initprobpmap::L) where {I, J, K, L} + @assert initprob isa Union{NonlinearProblem, NonlinearLeastSquaresProblem} + return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap) + end +end diff --git a/src/remake.jl b/src/remake.jl index e4540cd68..0ffd91003 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -125,10 +125,10 @@ function remake(prob::ODEProblem; f = missing, if f === missing if build_initializeprob - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob( + initialization_data = remake_initialization_data( prob.f.sys, prob.f, u0, tspan[1], p) else - initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing + initialization_data = nothing end if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) @@ -137,45 +137,21 @@ function remake(prob::ODEProblem; f = missing, wrapfun_iip( unwrapped_f(prob.f.f), (newu0, newu0, newp, - ptspan[1])); - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) + ptspan[1])); initialization_data) else _f = ODEFunction{iip, FunctionWrapperSpecialize}( wrapfun_oop( unwrapped_f(prob.f.f), (newu0, newp, - ptspan[1])); - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) + ptspan[1])); initialization_data) end else _f = prob.f - if __has_initializeprob(_f) + if __has_initialization_data(_f) props = getproperties(_f) - @reset props.initializeprob = initializeprob + @reset props.initialization_data = initialization_data props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) - end - if __has_update_initializeprob!(_f) - props = getproperties(_f) - @reset props.update_initializeprob! = update_initializeprob! - props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) - end - if __has_initializeprobmap(_f) - props = getproperties(_f) - @reset props.initializeprobmap = initializeprobmap - props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) - end - if __has_initializeprobpmap(_f) - props = getproperties(_f) - @reset props.initializeprobpmap = initializeprobpmap - props = values(props) - _f = parameterless_type(_f){ - iip, specialization(_f), map(typeof, props)...}(props...) + _f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...) end end elseif f isa AbstractODEFunction @@ -206,6 +182,9 @@ end """ remake_initializeprob(sys, scimlfn, u0, t0, p) +!! WARN +This method is deprecated. Please see `remake_initialization_data` + Re-create the initialization problem present in the function `scimlfn`, using the associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and `p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an @@ -223,6 +202,21 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p) scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap end +""" + remake_initialization_data(sys, scimlfn, u0, t0, p) + +Re-create the initialization data present in the function `scimlfn`, using the +associated system `sys` and the user provided new values of `u0`, initial time `t0` and +`p`. By default, this calls `remake_initializeprob` for backward compatibility and +attempts to construct an `OverrideInitData` from the result. + +Note that `u0` or `p` may be `missing` if the user does not provide a value for them. +""" +function remake_initialization_data(sys, scimlfn, u0, t0, p) + return reconstruct_initialization_data( + nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...) +end + """ remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing, _kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index b05ffbfd5..e64519fda 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -405,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, - O, TCV, SYS, - IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip} + O, TCV, + SYS, ID, NLP} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -423,10 +423,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW observed::O colorvec::TCV sys::SYS - initializeprob::IProb - update_initializeprob!::UIProb - initializeprobmap::IProbMap - initializeprobpmap::IProbPmap + initialization_data::ID nlprob::NLP end @@ -530,8 +527,8 @@ information on generating the SplitFunction from this symbolic engine. """ struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt, - TPJ, O, TCV, SYS, - IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip} + TPJ, O, + TCV, SYS, ID, NLP} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -550,11 +547,8 @@ struct SplitFunction{ observed::O colorvec::TCV sys::SYS - initializeprob::IProb - update_initializeprob!::UIProb - initializeprobmap::IProbMap - initializeprobpmap::IProbPmap nlprob::NLP + initialization_data::ID end @doc doc""" @@ -1529,7 +1523,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, UIProb, IProbMap, IProbPmap} <: + SYS, ID} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1545,10 +1539,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP observed::O colorvec::TCV sys::SYS - initializeprob::IProb - update_initializeprob!::UIProb - initializeprobmap::IProbMap - initializeprobpmap::IProbPmap + initialization_data::ID end """ @@ -2440,6 +2431,8 @@ function ODEFunction{iip, specialize}(f; initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, nlprob = __has_nlprob(f) ? f.nlprob : nothing, + initialization_data = __has_initialization_data(f) ? f.initialization_data : + nothing ) where {iip, specialize } @@ -2486,9 +2479,9 @@ function ODEFunction{iip, specialize}(f; _f = prepare_function(f) sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) - - @assert typeof(initializeprob) <: - Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) if specialize === NoSpecialize ODEFunction{iip, specialize, @@ -2497,11 +2490,10 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap, nlprob) + observed, _colorvec, sys, initdata, nlprob) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2510,16 +2502,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), - typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap), - typeof(nlprob)}(_f, mass_matrix, + typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap, nlprob) + observed, _colorvec, sys, initdata, nlprob) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2528,14 +2515,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap), - typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac, - jvp, vjp, jac_prototype, sparsity, Wfact, + typeof(sys), typeof(initdata), typeof(nlprob)}( + _f, mass_matrix, analytic, tgrad, + jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap, nlprob) + observed, _colorvec, sys, initdata, nlprob) end end @@ -2552,13 +2536,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any, Any, Any, Any}( + typeof(f.sys), Any, Any}( newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, - f.update_initializeprob!, f.initializeprobmap, - f.initializeprobpmap, f.nlprob) + f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2566,14 +2548,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!), - typeof(f.initializeprobmap), - typeof(f.initializeprobpmap), - typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}( + newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!, - f.initializeprobmap, f.initializeprobpmap, f.nlprob) + f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob) end end @@ -2704,8 +2683,8 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap, nlprob) + observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing, + initializeprobmap = nothing, initializeprobpmap = nothing, nlprob = nothing, initialization_data = nothing) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -2714,17 +2693,20 @@ end throw(NonconformingFunctionsError(["f2"])) end + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) + SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap), - typeof(initializeprobpmap), typeof(nlprob)}( + typeof(sys), typeof(initdata), typeof(nlprob)}( f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob) + initdata, nlprob) end function SplitFunction{iip, specialize}(f1, f2; mass_matrix = __has_mass_matrix(f1) ? @@ -2761,23 +2743,25 @@ function SplitFunction{iip, specialize}(f1, f2; f1.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing, - nlprob = __has_nlprob(f1) ? f1.nlprob : nothing + nlprob = __has_nlprob(f1) ? f1.nlprob : nothing, + initialization_data = __has_initialization_data(f1) ? f1.initialization_data : + nothing ) where {iip, specialize } sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) - @assert typeof(initializeprob) <: - Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, + Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap, - initializeprobpmap, initializeprobpmap, nlprob) + observed, colorvec, sys, initdata, nlprob) else SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(_func_cache), typeof(analytic), @@ -2785,13 +2769,11 @@ function SplitFunction{iip, specialize}(f1, f2; typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap), typeof(nlprob)}(f1, f2, + typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob) + initdata, nlprob) end end @@ -3420,7 +3402,9 @@ function DAEFunction{iip, specialize}(f; update_initializeprob! = __has_update_initializeprob!(f) ? f.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, - initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where { + initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, + initialization_data = __has_initialization_data(f) ? f.initialization_data : + nothing) where { iip, specialize } @@ -3452,33 +3436,29 @@ function DAEFunction{iip, specialize}(f; _f = prepare_function(f) sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) - - @assert typeof(initializeprob) <: - Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap) + _colorvec, sys, initdata) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), - typeof(initializeprobmap), - typeof(initializeprobpmap)}( + typeof(sys), typeof(initdata)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, update_initializeprob!, - initializeprobmap, initializeprobpmap) + _colorvec, sys, initdata) end end @@ -4397,6 +4377,14 @@ function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing) return sys end +function reconstruct_initialization_data( + initdata, initprob, update_initprob!, initprobmap, initprobpmap) + if initdata === nothing && initprob !== nothing + initdata = OverrideInitData(initprob, update_initprob!, initprobmap, initprobpmap) + end + return initdata +end + ########## Existence Functions # Check that field/property exists (may be nothing) @@ -4420,11 +4408,20 @@ __has_colorvec(f) = isdefined(f, :colorvec) __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) -__has_initializeprob(f) = isdefined(f, :initializeprob) -__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!) -__has_initializeprobmap(f) = isdefined(f, :initializeprobmap) -__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap) __has_nlprob(f) = isdefined(f, :nlprob) +function __has_initializeprob(f) + has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob) +end +function __has_update_initializeprob!(f) + has_initialization_data(f) && isdefined(f.initialization_data, :update_initializeprob!) +end +function __has_initializeprobmap(f) + has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobmap) +end +function __has_initializeprobpmap(f) + has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobpmap) +end +__has_initialization_data(f) = isdefined(f, :initialization_data) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -4438,16 +4435,20 @@ has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothin has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing function has_initializeprob(f::AbstractSciMLFunction) - __has_initializeprob(f) && f.initializeprob !== nothing + __has_initializeprob(f) && f.initialization_data.initializeprob !== nothing end function has_update_initializeprob!(f::AbstractSciMLFunction) - __has_update_initializeprob!(f) && f.update_initializeprob! !== nothing + __has_update_initializeprob!(f) && + f.initialization_data.update_initializeprob! !== nothing end function has_initializeprobmap(f::AbstractSciMLFunction) - __has_initializeprobmap(f) && f.initializeprobmap !== nothing + __has_initializeprobmap(f) && f.initialization_data.initializeprobmap !== nothing end function has_initializeprobpmap(f::AbstractSciMLFunction) - __has_initializeprobpmap(f) && f.initializeprobpmap !== nothing + __has_initializeprobpmap(f) && f.initialization_data.initializeprobpmap !== nothing +end +function has_initialization_data(f) + __has_initialization_data(f) && f.initialization_data !== nothing end function has_syms(f::AbstractSciMLFunction) @@ -4598,3 +4599,15 @@ function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym::Symb end SymbolicIndexingInterface.constant_structure(::AbstractSciMLFunction) = true + +function Base.getproperty(x::Union{ODEFunction, SplitFunction, DAEFunction}, sym::Symbol) + if sym == :initializeprob || sym == :update_initializeprob! || + sym == :initializeprobmap || sym == :initializeprobpmap + if x.initialization_data === nothing + return nothing + else + return getproperty(x.initialization_data, sym) + end + end + return getfield(x, sym) +end diff --git a/test/downstream/ensemble_diffeq.jl b/test/downstream/ensemble_diffeq.jl index 2985cc72b..2fed8ed77 100644 --- a/test/downstream/ensemble_diffeq.jl +++ b/test/downstream/ensemble_diffeq.jl @@ -2,18 +2,18 @@ using OrdinaryDiffEq, Test A = [1 2 3 4] -prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0)) +prob = ODEProblem((u, p, t) -> A * u, ones(2, 2), (0.0, 1.0)) function prob_func(prob, i, repeat) remake(prob, u0 = i * prob.u0) end ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat = 0.01) @test sim isa EnsembleSolution -@test size(sim[1,:,:,:]) == (2,101,10) -@test size(sim[:,1,:,:]) == (2,101,10) -@test size(sim[:,:,1,:]) == (2,2,10) -@test size(sim[:,:,:,1]) == (2,2,101) -@test Array(sim)[1,:,:,:] == sim[1,:,:,:] -@test Array(sim)[:,1,:,:] == sim[:,1,:,:] -@test Array(sim)[:,:,1,:] == sim[:,:,1,:] -@test Array(sim)[:,:,:,1] == sim[:,:,:,1] \ No newline at end of file +@test size(sim[1, :, :, :]) == (2, 101, 10) +@test size(sim[:, 1, :, :]) == (2, 101, 10) +@test size(sim[:, :, 1, :]) == (2, 2, 10) +@test size(sim[:, :, :, 1]) == (2, 2, 101) +@test Array(sim)[1, :, :, :] == sim[1, :, :, :] +@test Array(sim)[:, 1, :, :] == sim[:, 1, :, :] +@test Array(sim)[:, :, 1, :] == sim[:, :, 1, :] +@test Array(sim)[:, :, :, 1] == sim[:, :, :, 1] diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index 92f20065b..5180f8cf6 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -17,9 +17,9 @@ prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) for i in 1:3 - @test sol[1,:,i] == sol.u[i][x] - @test sol[2,:,i] == sol.u[i][y] + @test sol[1, :, i] == sol.u[i][x] + @test sol[2, :, i] == sol.u[i][y] end # Ensemble is a recursive array @test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] -@test only.(sol(1.0, idxs = [x])) ≈ [sol[i][1, end] for i in 1:3] \ No newline at end of file +@test only.(sol(1.0, idxs = [x])) ≈ [sol[i][1, end] for i in 1:3]