Skip to content

Commit 046aefe

Browse files
Merge pull request #1031 from SciML/adapt
Add immutable problem types and adapt
2 parents aafccc1 + a2e0fa4 commit 046aefe

File tree

6 files changed

+174
-3
lines changed

6 files changed

+174
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.91.1"
4+
version = "2.92.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
9+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1011
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1112
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -55,6 +56,7 @@ SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]
5556
[compat]
5657
ADTypes = "0.2.5,1.0.0"
5758
Accessors = "0.1.36"
59+
Adapt = "4"
5860
ArrayInterface = "7.6"
5961
ChainRules = "1.58.0"
6062
ChainRulesCore = "1.18"

src/SciMLBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import Accessors: @set, @reset, @delete, @insert
2626
using Moshi.Data: @data
2727
using Moshi.Match: @match
2828
import StaticArraysCore
29+
import Adapt: adapt_structure, adapt
2930

3031
using Reexport
3132
using SciMLOperators
@@ -752,6 +753,8 @@ include("integrator_interface.jl")
752753
include("remake.jl")
753754
include("callbacks.jl")
754755

756+
include("adapt.jl")
757+
755758
include("deprecated.jl")
756759

757760
import PrecompileTools

src/adapt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
function adapt_structure(to, prob::Union{NonlinearProblem{<:Any, <:Any, iip}, ImmutableNonlinearProblem{<:Any, <:Any, iip}}) where {iip}
2+
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(adapt(to, prob.f.f)),
3+
adapt(to, prob.u0),
4+
adapt(to, prob.p);
5+
adapt(to, prob.kwargs)...)
6+
end
7+
8+
function adapt_structure(to, prob::Union{ODEProblem{<:Any, <:Any, iip}, ImmutableODEProblem{<:Any, <:Any, iip}}) where {iip}
9+
ImmutableODEProblem{iip, FullSpecialize}(adapt(to, prob.f),
10+
adapt(to, prob.u0),
11+
adapt(to, prob.tspan),
12+
adapt(to, prob.p);
13+
adapt(to, prob.kwargs)...)
14+
end
15+
16+
function adapt_structure(to, f::ODEFunction{iip}) where {iip}
17+
if f.mass_matrix !== I && f.initialization_data !== nothing
18+
error("Adaptation to GPU failed: DAEs of ModelingToolkit currently not supported.")
19+
end
20+
ODEFunction{iip, FullSpecialize}(f.f, jac = f.jac, mass_matrix = f.mass_matrix)
21+
end

src/initialization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct OverrideInitData{
4444
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
4545
initprobpmap::L, metadata::M, is_update_oop::O) where {I, J, K, L, M, O}
4646
@assert initprob isa
47-
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
47+
Union{SCCNonlinearProblem, ImmutableNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
4848
return new{I, J, K, L, M, O}(
4949
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
5050
end

src/problems/nonlinear_problems.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,6 @@ When a keyword argument is `nothing`, the default behaviour of the solver is use
579579
* `alias_u0::Union{Bool, Nothing}`: alias the `u0` array.
580580
* `alias::Union{Bool, Nothing}`: sets all fields of the `NonlinearAliasSpecifier` to `alias`.
581581
"""
582-
583582
struct NonlinearAliasSpecifier <: AbstractAliasSpecifier
584583
alias_p::Union{Bool, Nothing}
585584
alias_f::Union{Bool, Nothing}
@@ -596,3 +595,60 @@ struct NonlinearAliasSpecifier <: AbstractAliasSpecifier
596595
end
597596
end
598597
end
598+
599+
struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
600+
AbstractNonlinearProblem{uType, iip}
601+
f::F
602+
u0::uType
603+
p::P
604+
problem_type::PT
605+
kwargs::K
606+
607+
SciMLBase.@add_kwonly function ImmutableNonlinearProblem{iip}(
608+
f::AbstractNonlinearFunction{iip}, u0, p = NullParameters(),
609+
problem_type = StandardNonlinearProblem(); kwargs...) where {iip}
610+
if haskey(kwargs, :p)
611+
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to \
612+
`NonlinearProblem`. This is not supported.")
613+
end
614+
SciMLBase.warn_paramtype(p)
615+
return new{
616+
typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(
617+
f, u0, p, problem_type, kwargs)
618+
end
619+
620+
"""
621+
Define a steady state problem using the given function.
622+
`isinplace` optionally sets whether the function is inplace or not.
623+
This is determined automatically, but not inferred.
624+
"""
625+
function ImmutableNonlinearProblem{iip}(
626+
f, u0, p = NullParameters(); kwargs...) where {iip}
627+
return ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
628+
end
629+
end
630+
631+
"""
632+
Define a nonlinear problem using an instance of [`AbstractNonlinearFunction`](@ref).
633+
"""
634+
function ImmutableNonlinearProblem(
635+
f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
636+
return ImmutableNonlinearProblem{SciMLBase.isinplace(f)}(f, u0, p; kwargs...)
637+
end
638+
639+
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
640+
return ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
641+
end
642+
643+
"""
644+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem.
645+
"""
646+
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
647+
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(prob.f, prob.u0, prob.p)
648+
end
649+
650+
function Base.convert(
651+
::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
652+
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(
653+
prob.f, prob.u0, prob.p, prob.problem_type; prob.kwargs...)
654+
end

src/problems/ode_problems.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,3 +547,92 @@ struct ODEAliasSpecifier <: AbstractAliasSpecifier
547547
end
548548
end
549549
end
550+
551+
struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <:
552+
AbstractODEProblem{uType, tType, isinplace}
553+
"""The ODE is `du = f(u,p,t)` for out-of-place and f(du,u,p,t) for in-place."""
554+
f::F
555+
"""The initial condition is `u(tspan[1]) = u0`."""
556+
u0::uType
557+
"""The solution `u(t)` will be computed for `tspan[1] ≤ t ≤ tspan[2]`."""
558+
tspan::tType
559+
"""Constant parameters to be supplied as the second argument of `f`."""
560+
p::P
561+
"""A callback to be applied to every solver which uses the problem."""
562+
kwargs::K
563+
"""An internal argument for storing traits about the solving process."""
564+
problem_type::PT
565+
@add_kwonly function ImmutableODEProblem{iip}(f::AbstractODEFunction{iip},
566+
u0, tspan, p = NullParameters(),
567+
problem_type = StandardODEProblem();
568+
kwargs...) where {iip}
569+
_u0 = prepare_initial_state(u0)
570+
_tspan = promote_tspan(tspan)
571+
warn_paramtype(p)
572+
new{typeof(_u0), typeof(_tspan),
573+
isinplace(f), typeof(p), typeof(f),
574+
typeof(kwargs),
575+
typeof(problem_type)}(f,
576+
_u0,
577+
_tspan,
578+
p,
579+
kwargs,
580+
problem_type)
581+
end
582+
583+
"""
584+
ImmutableODEProblem{isinplace}(f,u0,tspan,p=NullParameters(),callback=CallbackSet())
585+
586+
Define an ODE problem with the specified function.
587+
`isinplace` optionally sets whether the function is inplace or not.
588+
This is determined automatically, but not inferred.
589+
"""
590+
function ImmutableODEProblem{iip}(f,
591+
u0,
592+
tspan,
593+
p = NullParameters();
594+
kwargs...) where {iip}
595+
_u0 = prepare_initial_state(u0)
596+
_tspan = promote_tspan(tspan)
597+
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
598+
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
599+
end
600+
601+
@add_kwonly function ImmutableODEProblem{iip, recompile}(f, u0, tspan,
602+
p = NullParameters();
603+
kwargs...) where {iip, recompile}
604+
ImmutableODEProblem{iip}(ODEFunction{iip, recompile}(f), u0, tspan, p; kwargs...)
605+
end
606+
end
607+
608+
"""
609+
ImmutableODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet())
610+
611+
Define an ODE problem from an [`ODEFunction`](@ref).
612+
"""
613+
function ImmutableODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
614+
ImmutableODEProblem{isinplace(f)}(f, u0, tspan, args...; kwargs...)
615+
end
616+
617+
function ImmutableODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
618+
iip = isinplace(f, 4)
619+
_u0 = prepare_initial_state(u0)
620+
_tspan = promote_tspan(tspan)
621+
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
622+
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
623+
end
624+
625+
staticarray_itize(x) = x
626+
staticarray_itize(x::Vector) = StaticArraysCore.SVector{length(x)}(x)
627+
staticarray_itize(x::StaticArraysCore.SizedVector) = StaticArraysCore.SVector{length(x)}(x)
628+
staticarray_itize(x::Matrix) = StaticArraysCore.SMatrix{size(x)...}(x)
629+
staticarray_itize(x::StaticArraysCore.SizedMatrix) = StaticArraysCore.SMatrix{size(x)...}(x)
630+
631+
function Base.convert(::Type{ImmutableODEProblem}, prob::T) where {T <: ODEProblem}
632+
ImmutableODEProblem(prob.f,
633+
staticarray_itize(prob.u0),
634+
prob.tspan,
635+
staticarray_itize(prob.p),
636+
prob.problem_type;
637+
prob.kwargs...)
638+
end

0 commit comments

Comments
 (0)