Skip to content

Commit 19fbbfd

Browse files
feat: implement ODEProblem and ODEFunction for System
1 parent dc5acfb commit 19fbbfd

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

src/problems/odeproblem.jl

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
@fallback_iip_specialize function SciMLBase.ODEFunction{iip, spec}(
2+
sys::System, _d = nothing, u0 = nothing, p = nothing; tgrad = false, jac = false,
3+
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
4+
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
5+
simplify = false, cse = true, initialization_data = nothing,
6+
check_compatibility = true, kwargs...) where {iip, spec}
7+
check_complete(sys, ODEFunction)
8+
check_compatibility && check_compatible_system(ODEFunction, sys)
9+
10+
dvs = unknowns(sys)
11+
ps = parameters(sys)
12+
f_gen = generate_rhs(sys, dvs, ps; expression = Val{true},
13+
expression_module = eval_module, checkbounds = checkbounds, cse,
14+
kwargs...)
15+
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
16+
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
17+
18+
if spec === SciMLBase.FunctionWrapperSpecialize && iip
19+
if u0 === nothing || p === nothing || t === nothing
20+
error("u0, p, and t must be specified for FunctionWrapperSpecialize on ODEFunction.")
21+
end
22+
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
23+
end
24+
25+
if tgrad
26+
tgrad_gen = generate_tgrad(sys, dvs, ps;
27+
simplify = simplify,
28+
expression = Val{true},
29+
expression_module = eval_module, cse,
30+
checkbounds = checkbounds, kwargs...)
31+
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
32+
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
33+
else
34+
_tgrad = nothing
35+
end
36+
37+
if jac
38+
jac_gen = generate_jacobian(sys, dvs, ps;
39+
simplify = simplify, sparse = sparse,
40+
expression = Val{true},
41+
expression_module = eval_module, cse,
42+
checkbounds = checkbounds, kwargs...)
43+
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
44+
45+
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
46+
else
47+
_jac = nothing
48+
end
49+
50+
M = calculate_massmatrix(sys)
51+
52+
_M = if sparse && !(u0 === nothing || M === I)
53+
SparseArrays.sparse(M)
54+
elseif u0 === nothing || M === I
55+
M
56+
else
57+
ArrayInterface.restructure(u0 .* u0', M)
58+
end
59+
60+
observedfun = ObservedFunctionCache(
61+
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
62+
63+
if sparse
64+
uElType = u0 === nothing ? Float64 : eltype(u0)
65+
W_prototype = similar(W_sparsity(sys), uElType)
66+
else
67+
W_prototype = nothing
68+
end
69+
70+
ODEFunction{iip, spec}(f;
71+
sys = sys,
72+
jac = _jac,
73+
tgrad = _tgrad,
74+
mass_matrix = _M,
75+
jac_prototype = W_prototype,
76+
observed = observedfun,
77+
sparsity = sparsity ? W_sparsity(sys) : nothing,
78+
analytic = analytic,
79+
initialization_data)
80+
end
81+
82+
@fallback_iip_specialize function SciMLBase.ODEProblem{iip, spec}(
83+
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
84+
callback = nothing, check_length = true, eval_expression = false,
85+
eval_module = @__MODULE__, check_compatibility = true, kwargs...) where {iip, spec}
86+
check_complete(sys, ODEProblem)
87+
check_compatibility && check_compatible_system(ODEProblem, sys)
88+
89+
f, u0, p = process_SciMLProblem(ODEFunction{iip, spec}, sys, u0map, parammap;
90+
t = tspan !== nothing ? tspan[1] : tspan,
91+
check_length, eval_expression, eval_module, kwargs...)
92+
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
93+
94+
kwargs = filter_kwargs(kwargs)
95+
96+
kwargs1 = (;)
97+
if cbs !== nothing
98+
kwargs1 = merge(kwargs1, (callback = cbs,))
99+
end
100+
101+
tstops = SymbolicTstops(sys; eval_expression, eval_module)
102+
if tstops !== nothing
103+
kwargs1 = merge(kwargs1, (; tstops))
104+
end
105+
106+
# Call `remake` so it runs initialization if it is trivial
107+
return remake(ODEProblem{iip}(
108+
f, u0, tspan, p, StandardODEProblem(); kwargs1..., kwargs...))
109+
end
110+
111+
function check_compatible_system(T::Union{Type{ODEFunction}, Type{ODEProblem}}, sys::System)
112+
if !is_time_dependent(sys)
113+
throw(SystemCompatibilityError("""
114+
`$T` requires a time-dependent system.
115+
"""))
116+
end
117+
118+
cost = get_costs(sys)
119+
if cost isa Vector && !isempty(cost) ||
120+
cost isa Union{BasicSymbolic, Real} && !_iszero(cost)
121+
throw(SystemCompatibilityError("""
122+
`$T` will not optimize solutions of systems that have associated cost \
123+
functions. Solvers for optimal control problems are forthcoming. In order to \
124+
bypass this error (e.g. to check the cost of a regular solution), pass \
125+
`allow_cost = true` into the constructor.
126+
"""))
127+
end
128+
129+
if !isempty(constraints(sys))
130+
throw(SystemCompatibilityError("""
131+
A system with constraints cannot be used to construct an `$T`. Consider a \
132+
`BVProblem` instead.
133+
"""))
134+
end
135+
136+
if !isempty(jumps(sys))
137+
throw(SystemCompatibilityError("""
138+
A system with jumps cannot be used to construct an `$T`. Consider a \
139+
`JumpProblem` instead.
140+
"""))
141+
end
142+
143+
if get_noise_eqs(sys) !== nothing
144+
throw(SystemCompatibilityError("""
145+
A system with jumps cannot be used to construct an `$T`. Consider an \
146+
`SDEProblem` instead.
147+
"""))
148+
end
149+
end

0 commit comments

Comments
 (0)