From c9cb8f80964770d8d2eea7a9450016328cbd5fa6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Oct 2024 17:06:06 +0530 Subject: [PATCH 1/5] feat: add implementations of `CheckInit` and `OverrideInit` --- src/SciMLBase.jl | 7 +- src/initialization.jl | 162 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 166 insertions(+), 3 deletions(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index ea78cc02a..986ec0949 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -348,6 +348,11 @@ $(TYPEDEF) """ struct CheckInit <: DAEInitializationAlgorithm end +""" +$(TYPEDEF) +""" +struct OverrideInit <: DAEInitializationAlgorithm end + # PDE Discretizations """ @@ -654,7 +659,6 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context. struct TrackerOriginator <: ADOriginator end include("utils.jl") -include("initialization.jl") include("function_wrappers.jl") include("scimlfunctions.jl") include("alg_traits.jl") @@ -740,6 +744,7 @@ include("ensemble/ensemble_problems.jl") include("ensemble/basic_ensemble_solve.jl") include("ensemble/ensemble_analysis.jl") +include("initialization.jl") include("solve.jl") include("interpolation.jl") include("integrator_interface.jl") diff --git a/src/initialization.jl b/src/initialization.jl index 9f7567b98..86c71560d 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} """ initializeprob::IProb """ - A function which takes `(initializeprob, prob)` and updates + A function which takes `(initializeprob, value_provider)` and updates the parameters of the former with their values in the latter. + If absent (`nothing`) this will not be called, and the parameters + in `initializeprob` will be used without modification. `value_provider` + refers to a value provider as defined by SymbolicIndexingInterface.jl. + Usually this will refer to a problem or integrator. """ update_initializeprob!::UIProb """ @@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} initializeprobmap::IProbMap """ A function which takes the solution of `initializeprob` and returns - the parameter object of the original problem. + the parameter object of the original problem. If absent (`nothing`), + this will not be called and the parameters of the problem being + initialized will be returned as-is. """ initializeprobpmap::IProbPmap @@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap) end end + +""" + get_initial_values(prob, valp, f, alg, isinplace; kwargs...) + +Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm, +and a boolean indicating whether the initialization process was successful. Keyword +arguments to this function are dependent on the initialization algorithm. `prob` is only +required for dispatching. `valp` refers the appropriate data structure from which the +current state and parameter values should be obtained. `valp` is a non-timeseries value +provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the +problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}` +if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise. +""" +function get_initial_values end + +struct CheckInitFailureError <: Exception + normresid::Any + abstol::Any +end + +function Base.showerror(io::IO, e::CheckInitFailureError) + print(io, + "CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)") +end + +struct OverrideInitMissingAlgorithm <: Exception end + +function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm) + print(io, + "OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.") +end + +""" +Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if +it is in-place or simply calling the function if not. +""" +function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...) + tmp = first(get_tmp_cache(integrator)) + f(tmp, args...) + return tmp +end + +function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...) + return f(args...) +end + +""" + $(TYPEDSIGNATURES) + +A utility function equivalent to `Base.vec` but also handles `Number` and +`AbstractSciMLScalarOperator`. +""" +_vec(v) = vec(v) +_vec(v::Number) = v +_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v +_vec(v::AbstractVector) = v + +""" + $(TYPEDSIGNATURES) + +Check if the algebraic constraints are satisfied, and error if they aren't. Returns +the `u0` and `p` as-is, and is always successful if it returns. Valid only for +`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. +""" +function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit, + isinplace::Union{Val{true}, Val{false}}; kwargs...) + u0 = state_values(integrator) + p = parameter_values(integrator) + t = current_time(integrator) + M = f.mass_matrix + + algebraic_vars = [all(iszero, x) for x in eachcol(M)] + algebraic_eqs = [all(iszero, x) for x in eachrow(M)] + (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return + update_coefficients!(M, u0, p, t) + tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t) + tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) + + normresid = integrator.opts.internalnorm(tmp, t) + if normresid > integrator.opts.abstol + throw(CheckInitFailureError(normresid, integrator.opts.abstol)) + end + return u0, p, true +end + +""" +Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if +it is in-place or simply calling the function if not. +""" +function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...) + tmp = get_tmp_cache(integrator)[2] + f(tmp, args...) + return tmp +end + +function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...) + return f(args...) +end + +function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit, + isinplace::Union{Val{true}, Val{false}}; kwargs...) + u0 = state_values(integrator) + p = parameter_values(integrator) + t = current_time(integrator) + + resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t) + normresid = integrator.opts.internalnorm(resid, t) + if normresid > integrator.opts.abstol + throw(CheckInitFailureError(normresid, integrator.opts.abstol)) + end + return u0, p, true +end + +""" + $(TYPEDSIGNATURES) + +Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and +`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`. +If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is. +The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword +argument, failing which this function will throw an error. The success value returned +depends on the success of the nonlinear solve. +""" +function get_initial_values(prob, valp, f, alg::OverrideInit, + isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...) + u0 = state_values(valp) + p = parameter_values(valp) + + if !has_initialization_data(f) + return u0, p, true + end + + initdata::OverrideInitData = f.initialization_data + initprob = initdata.initializeprob + + if nlsolve_alg === nothing + throw(OverrideInitMissingAlgorithm()) + end + + if initdata.update_initializeprob! !== nothing + initdata.update_initializeprob!(initprob, valp) + end + + nlsol = solve(initprob, nlsolve_alg) + + u0 = initdata.initializeprobmap(nlsol) + if initdata.initializeprobpmap !== nothing + p = initdata.initializeprobpmap(nlsol) + end + + return u0, p, SciMLBase.successful_retcode(nlsol) +end From 676403eae4792b35932c61a2f2e1ac00b1b78192 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 4 Nov 2024 20:07:45 +0530 Subject: [PATCH 2/5] test: test new initialization features --- test/initialization.jl | 156 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 + 2 files changed, 159 insertions(+) create mode 100644 test/initialization.jl diff --git a/test/initialization.jl b/test/initialization.jl new file mode 100644 index 000000000..ca8fb6b6c --- /dev/null +++ b/test/initialization.jl @@ -0,0 +1,156 @@ +using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test + +@testset "CheckInit" begin + @testset "ODEProblem" begin + function rhs(u, p, t) + return [u[1] * t, u[1]^2 - u[2]^2] + end + function rhs!(du, u, p, t) + du[1] = u[1] * t + du[2] = u[1]^2 - u[2]^2 + end + + oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0]) + iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0]) + + @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] + prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) + integ = init(prob) + u0, _, success = SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + @test success + @test u0 == prob.u0 + + integ.u[2] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + end + end + + @testset "DAEProblem" begin + function daerhs(du, u, p, t) + return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2] + end + function daerhs!(resid, du, u, p, t) + resid[1] = du[1] - u[1] * t - p + resid[2] = u[1]^2 - u[2]^2 + end + + oopfn = DAEFunction{false}(daerhs) + iipfn = DAEFunction{true}(daerhs!) + + @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] + prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0) + integ = init(prob, DImplicitEuler()) + u0, _, success = SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + @test success + @test u0 == prob.u0 + + integ.u[2] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + + integ.u[2] = 1.0 + integ.du[1] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + end + end +end + +@testset "OverrideInit" begin + function rhs2(u, p, t) + return [u[1] * t + p, u[1]^2 - u[2]^2] + end + + @testset "No-op without `initialization_data`" begin + prob = ODEProblem(rhs2, [1.0, 2.0], (0.0, 1.0), 1.0) + integ = init(prob) + integ.u[2] = 3.0 + u0, p, success = SciMLBase.get_initial_values( + prob, integ, prob.f, SciMLBase.OverrideInit(), Val(false)) + @test u0 ≈ [1.0, 3.0] + @test success + end + + # unknowns are u[2], p. Parameter is u[1] + initprob = NonlinearProblem([1.0, 1.0], [1.0]) do x, _u1 + u2, p = x + u1 = _u1[1] + return [u1^2 - u2^2, p^2 - 2p + 1] + end + update_initializeprob! = function (iprob, integ) + iprob.p[1] = integ.u[1] + end + initprobmap = function (nlsol) + return [parameter_values(nlsol)[1], nlsol.u[1]] + end + initprobpmap = function (nlsol) + return nlsol.u[2] + end + initialization_data = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap) + fn = ODEFunction(rhs2; initialization_data) + prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob; initializealg = NoInit()) + + @testset "Errors without `nlsolve_alg`" begin + @test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), Val(false)) + end + + @testset "Solves" begin + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg = NewtonRaphson()) + + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success + + initprob.p[1] = 1.0 + end + + @testset "Solves with non-integrator value provider" begin + _integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t) + u0, p, success = SciMLBase.get_initial_values( + prob, _integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg = NewtonRaphson()) + + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success + + initprob.p[1] = 1.0 + end + + @testset "Solves without `update_initializeprob!`" begin + initdata = SciMLBase.@set initialization_data.update_initializeprob! = nothing + fn = ODEFunction(rhs2; initialization_data = initdata) + prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg = NewtonRaphson()) + @test u0 ≈ [1.0, 1.0] + @test p ≈ 1.0 + @test success + end + + @testset "Solves without `initializeprobpmap`" begin + initdata = SciMLBase.@set initialization_data.initializeprobpmap = nothing + fn = ODEFunction(rhs2; initialization_data = initdata) + prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg = NewtonRaphson()) + + @test u0 ≈ [2.0, 2.0] + @test p ≈ 0.0 + @test success + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8874ccac1..747c7251d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,9 @@ end @time @safetestset "Serialization tests" begin include("serialization_tests.jl") end + @time @safetestset "Initialization" begin + include("initialization.jl") + end end if !is_APPVEYOR && From 736b3d1ccc2e071e1c69d0062f47a1a1985f6867 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 4 Nov 2024 21:28:19 +0530 Subject: [PATCH 3/5] build: add NonlinearSolve as test dependency --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2401a1e5a..dbd9baf03 100644 --- a/Project.toml +++ b/Project.toml @@ -70,6 +70,7 @@ LinearAlgebra = "1.10" Logging = "1.10" Makie = "0.20, 0.21" Markdown = "1.10" +NonlinearSolve = "2, 3" PartialFunctions = "1.1" PrecompileTools = "1.2" Preferences = "1.3" @@ -98,6 +99,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -114,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"] +test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "NonlinearSolve", "OrdinaryDiffEq", "ForwardDiff", "Tables"] From 683f9d1029da92512c23d08c7798290e0e8b0793 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 5 Nov 2024 11:28:46 -0100 Subject: [PATCH 4/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dbd9baf03..b599088a7 100644 --- a/Project.toml +++ b/Project.toml @@ -70,7 +70,7 @@ LinearAlgebra = "1.10" Logging = "1.10" Makie = "0.20, 0.21" Markdown = "1.10" -NonlinearSolve = "2, 3" +NonlinearSolve = "3, 4" PartialFunctions = "1.1" PrecompileTools = "1.2" Preferences = "1.3" From b3105cc749e3a1d3b0ebd4f23fd83b1fd4fe8bd6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 6 Nov 2024 22:20:27 +0530 Subject: [PATCH 5/5] refactor: format --- src/solutions/save_idxs.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index db21ac72b..3d86f306a 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -372,7 +372,8 @@ function get_save_idxs_and_saved_subsystem(prob, save_idxs) if isempty(_save_idxs) # no states to save save_idxs = Int[] - elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + elseif !(save_idxs isa AbstractArray) || + symbolic_type(save_idxs) != NotSymbolic() # only a single state to save, and save it as a scalar timeseries instead of # single-element array save_idxs = only(_save_idxs)