Skip to content

Commit 4aaae10

Browse files
Merge branch 'master' into dg/sym_oop
2 parents 53cc142 + 0a6cb3c commit 4aaae10

File tree

11 files changed

+213
-9
lines changed

11 files changed

+213
-9
lines changed

Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.91.0"
4+
version = "2.95.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
9+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1011
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1112
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -55,6 +56,7 @@ SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]
5556
[compat]
5657
ADTypes = "0.2.5,1.0.0"
5758
Accessors = "0.1.36"
59+
Adapt = "4"
5860
ArrayInterface = "7.6"
5961
ChainRules = "1.58.0"
6062
ChainRulesCore = "1.18"
@@ -83,7 +85,7 @@ RecipesBase = "1.3.4"
8385
RecursiveArrayTools = "3.27.2"
8486
Reexport = "1"
8587
RuntimeGeneratedFunctions = "0.5.12"
86-
SciMLOperators = "0.4.0, 1"
88+
SciMLOperators = "0.4, 1.3"
8789
SciMLStructures = "1.1"
8890
StableRNGs = "1.0"
8991
StaticArrays = "1.7"
@@ -101,7 +103,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
101103
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
102104
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
103105
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
104-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
105106
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
106107
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
107108
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -113,4 +114,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
113114
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
114115

115116
[targets]
116-
test = ["Aqua", "ForwardDiff", "MLStyle", "PartialFunctions", "Pkg", "Plots", "SafeTestsets", "Serialization", "StableRNGs", "StaticArrays", "Tables", "Test", "UnicodePlots", "Zygote"]
117+
test = ["Aqua", "ForwardDiff", "MLStyle", "PartialFunctions", "Pkg", "SafeTestsets", "Serialization", "StableRNGs", "StaticArrays", "Tables", "Test", "UnicodePlots", "Zygote"]

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ There are too many to name here. Check out the
209209
## Contributing
210210

211211
- Please refer to the
212-
[SciML ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://github.com/SciML/ColPrac/blob/master/README.md)
212+
[SciML ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://github.com/SciML/ColPrac)
213213
for guidance on PRs, issues, and other matters relating to contributing to SciML.
214214

215215
- See the [SciML Style Guide](https://github.com/SciML/SciMLStyle) for common coding practices and other style decisions.

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,12 @@ function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged
116116
out, EnsembleSolution_adjoint
117117
end
118118

119+
function ChainRulesCore.rrule(::Type{SciMLBase.IntervalNonlinearProblem}, args...; kwargs...)
120+
function IntervalNonlinearProblemAdjoint(ȳ)
121+
(NoTangent(), ȳ.f, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
122+
end
123+
124+
SciMLBase.IntervalNonlinearProblem(args...; kwargs...), IntervalNonlinearProblemAdjoint
119125
end
126+
127+
end

src/SciMLBase.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import Accessors: @set, @reset, @delete, @insert
2626
using Moshi.Data: @data
2727
using Moshi.Match: @match
2828
import StaticArraysCore
29+
import Adapt: adapt_structure, adapt
2930

3031
using Reexport
3132
using SciMLOperators
@@ -662,6 +663,13 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
662663
"""
663664
struct TrackerOriginator <: ADOriginator end
664665

666+
"""
667+
$(TYPEDEF)
668+
669+
Internal. Used for signifying the AD context comes from a Mooncake.jl context.
670+
"""
671+
struct MooncakeOriginator <: ADOriginator end
672+
665673
include("initialization.jl")
666674
include("ODE_nlsolve.jl")
667675
include("utils.jl")
@@ -752,6 +760,8 @@ include("integrator_interface.jl")
752760
include("remake.jl")
753761
include("callbacks.jl")
754762

763+
include("adapt.jl")
764+
755765
include("deprecated.jl")
756766

757767
import PrecompileTools

src/adapt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
function adapt_structure(to, prob::Union{NonlinearProblem{<:Any, <:Any, iip}, ImmutableNonlinearProblem{<:Any, <:Any, iip}}) where {iip}
2+
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(adapt(to, prob.f.f)),
3+
adapt(to, prob.u0),
4+
adapt(to, prob.p);
5+
adapt(to, prob.kwargs)...)
6+
end
7+
8+
function adapt_structure(to, prob::Union{ODEProblem{<:Any, <:Any, iip}, ImmutableODEProblem{<:Any, <:Any, iip}}) where {iip}
9+
ImmutableODEProblem{iip, FullSpecialize}(adapt(to, prob.f),
10+
adapt(to, prob.u0),
11+
adapt(to, prob.tspan),
12+
adapt(to, prob.p);
13+
adapt(to, prob.kwargs)...)
14+
end
15+
16+
function adapt_structure(to, f::ODEFunction{iip}) where {iip}
17+
if f.mass_matrix !== I && f.initialization_data !== nothing
18+
error("Adaptation to GPU failed: DAEs of ModelingToolkit currently not supported.")
19+
end
20+
ODEFunction{iip, FullSpecialize}(f.f, jac = f.jac, mass_matrix = f.mass_matrix)
21+
end

src/initialization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct OverrideInitData{
4444
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
4545
initprobpmap::L, metadata::M, is_update_oop::O) where {I, J, K, L, M, O}
4646
@assert initprob isa
47-
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
47+
Union{SCCNonlinearProblem, ImmutableNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
4848
return new{I, J, K, L, M, O}(
4949
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
5050
end

src/problems/linear_problems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ are specified via the `AbstractSciMLOperator` interface. For more details, see
2222
the [SciMLBase Documentation](https://docs.sciml.ai/SciMLBase/stable/).
2323
2424
Note that matrix-free versions of LinearProblem definitions are not compatible
25-
with all solvers. To check a solver for compatibility, use the function xxxxx.
25+
with all solvers. To check a solver for compatibility, use the function `needs_concrete_A(alg::AbstractLinearAlgorithm)`.
2626
2727
## Problem Type
2828

src/problems/nonlinear_problems.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,6 @@ When a keyword argument is `nothing`, the default behaviour of the solver is use
579579
* `alias_u0::Union{Bool, Nothing}`: alias the `u0` array.
580580
* `alias::Union{Bool, Nothing}`: sets all fields of the `NonlinearAliasSpecifier` to `alias`.
581581
"""
582-
583582
struct NonlinearAliasSpecifier <: AbstractAliasSpecifier
584583
alias_p::Union{Bool, Nothing}
585584
alias_f::Union{Bool, Nothing}
@@ -596,3 +595,60 @@ struct NonlinearAliasSpecifier <: AbstractAliasSpecifier
596595
end
597596
end
598597
end
598+
599+
struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
600+
AbstractNonlinearProblem{uType, iip}
601+
f::F
602+
u0::uType
603+
p::P
604+
problem_type::PT
605+
kwargs::K
606+
607+
SciMLBase.@add_kwonly function ImmutableNonlinearProblem{iip}(
608+
f::AbstractNonlinearFunction{iip}, u0, p = NullParameters(),
609+
problem_type = StandardNonlinearProblem(); kwargs...) where {iip}
610+
if haskey(kwargs, :p)
611+
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to \
612+
`NonlinearProblem`. This is not supported.")
613+
end
614+
SciMLBase.warn_paramtype(p)
615+
return new{
616+
typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(
617+
f, u0, p, problem_type, kwargs)
618+
end
619+
620+
"""
621+
Define a steady state problem using the given function.
622+
`isinplace` optionally sets whether the function is inplace or not.
623+
This is determined automatically, but not inferred.
624+
"""
625+
function ImmutableNonlinearProblem{iip}(
626+
f, u0, p = NullParameters(); kwargs...) where {iip}
627+
return ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
628+
end
629+
end
630+
631+
"""
632+
Define a nonlinear problem using an instance of [`AbstractNonlinearFunction`](@ref).
633+
"""
634+
function ImmutableNonlinearProblem(
635+
f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
636+
return ImmutableNonlinearProblem{SciMLBase.isinplace(f)}(f, u0, p; kwargs...)
637+
end
638+
639+
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
640+
return ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
641+
end
642+
643+
"""
644+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem.
645+
"""
646+
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
647+
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(prob.f, prob.u0, prob.p)
648+
end
649+
650+
function Base.convert(
651+
::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
652+
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(
653+
prob.f, prob.u0, prob.p, prob.problem_type; prob.kwargs...)
654+
end

src/problems/ode_problems.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,3 +547,92 @@ struct ODEAliasSpecifier <: AbstractAliasSpecifier
547547
end
548548
end
549549
end
550+
551+
struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <:
552+
AbstractODEProblem{uType, tType, isinplace}
553+
"""The ODE is `du = f(u,p,t)` for out-of-place and f(du,u,p,t) for in-place."""
554+
f::F
555+
"""The initial condition is `u(tspan[1]) = u0`."""
556+
u0::uType
557+
"""The solution `u(t)` will be computed for `tspan[1] ≤ t ≤ tspan[2]`."""
558+
tspan::tType
559+
"""Constant parameters to be supplied as the second argument of `f`."""
560+
p::P
561+
"""A callback to be applied to every solver which uses the problem."""
562+
kwargs::K
563+
"""An internal argument for storing traits about the solving process."""
564+
problem_type::PT
565+
@add_kwonly function ImmutableODEProblem{iip}(f::AbstractODEFunction{iip},
566+
u0, tspan, p = NullParameters(),
567+
problem_type = StandardODEProblem();
568+
kwargs...) where {iip}
569+
_u0 = prepare_initial_state(u0)
570+
_tspan = promote_tspan(tspan)
571+
warn_paramtype(p)
572+
new{typeof(_u0), typeof(_tspan),
573+
isinplace(f), typeof(p), typeof(f),
574+
typeof(kwargs),
575+
typeof(problem_type)}(f,
576+
_u0,
577+
_tspan,
578+
p,
579+
kwargs,
580+
problem_type)
581+
end
582+
583+
"""
584+
ImmutableODEProblem{isinplace}(f,u0,tspan,p=NullParameters(),callback=CallbackSet())
585+
586+
Define an ODE problem with the specified function.
587+
`isinplace` optionally sets whether the function is inplace or not.
588+
This is determined automatically, but not inferred.
589+
"""
590+
function ImmutableODEProblem{iip}(f,
591+
u0,
592+
tspan,
593+
p = NullParameters();
594+
kwargs...) where {iip}
595+
_u0 = prepare_initial_state(u0)
596+
_tspan = promote_tspan(tspan)
597+
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
598+
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
599+
end
600+
601+
@add_kwonly function ImmutableODEProblem{iip, recompile}(f, u0, tspan,
602+
p = NullParameters();
603+
kwargs...) where {iip, recompile}
604+
ImmutableODEProblem{iip}(ODEFunction{iip, recompile}(f), u0, tspan, p; kwargs...)
605+
end
606+
end
607+
608+
"""
609+
ImmutableODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet())
610+
611+
Define an ODE problem from an [`ODEFunction`](@ref).
612+
"""
613+
function ImmutableODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
614+
ImmutableODEProblem{isinplace(f)}(f, u0, tspan, args...; kwargs...)
615+
end
616+
617+
function ImmutableODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
618+
iip = isinplace(f, 4)
619+
_u0 = prepare_initial_state(u0)
620+
_tspan = promote_tspan(tspan)
621+
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
622+
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
623+
end
624+
625+
staticarray_itize(x) = x
626+
staticarray_itize(x::Vector) = StaticArraysCore.SVector{length(x)}(x)
627+
staticarray_itize(x::StaticArraysCore.SizedVector) = StaticArraysCore.SVector{length(x)}(x)
628+
staticarray_itize(x::Matrix) = StaticArraysCore.SMatrix{size(x)...}(x)
629+
staticarray_itize(x::StaticArraysCore.SizedMatrix) = StaticArraysCore.SMatrix{size(x)...}(x)
630+
631+
function Base.convert(::Type{ImmutableODEProblem}, prob::T) where {T <: ODEProblem}
632+
ImmutableODEProblem(prob.f,
633+
staticarray_itize(prob.u0),
634+
prob.tspan,
635+
staticarray_itize(prob.p),
636+
prob.problem_type;
637+
prob.kwargs...)
638+
end

src/retcodes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ EnumX.@enumx ReturnCode begin
412412
413413
- For nonlinear least squares optimizations, this is given for local minima which exceed
414414
the chosen tolerance, i.e. `f(x)=resid` where `||resid||>tol` so it's not considered
415-
ReturnCode.Success but it is still considered a sucessful return of the solver since
415+
ReturnCode.Success but it is still considered a successful return of the solver since
416416
it's a valid local minima (and there no minima which achieves the tolerance).
417417
418418
## Properties

0 commit comments

Comments
 (0)