Skip to content

Commit 9607c39

Browse files
feat: add SemilinearODEFunction and SemilinearODEProblem
1 parent e972ff0 commit 9607c39

File tree

4 files changed

+639
-1
lines changed

4 files changed

+639
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2626
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2727
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2828
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
29+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2930
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -44,6 +45,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
4445
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4546
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4647
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
48+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
4749
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4850
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4951
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -114,6 +116,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
114116
EnumX = "1.0.4"
115117
ExprTools = "0.1.10"
116118
FMI = "0.14"
119+
FillArrays = "1.13.0"
117120
FindFirstFunctions = "1"
118121
ForwardDiff = "0.10.3, 1"
119122
FunctionWrappers = "1.1"
@@ -141,6 +144,7 @@ OrdinaryDiffEq = "6.82.0"
141144
OrdinaryDiffEqCore = "1.15.0"
142145
OrdinaryDiffEqDefault = "1.2"
143146
OrdinaryDiffEqNonlinearSolve = "1.5.0"
147+
PreallocationTools = "0.4.27"
144148
PrecompileTools = "1"
145149
Pyomo = "0.1.0"
146150
REPL = "1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ const DQ = DynamicQuantities
9898
import DifferentiationInterface as DI
9999
using ADTypes: AutoForwardDiff
100100
import SciMLPublic: @public
101+
import PreallocationTools
102+
import PreallocationTools: DiffCache
103+
import FillArrays
101104

102105
export @derivatives
103106

@@ -256,6 +259,7 @@ export IntervalNonlinearProblem
256259
export OptimizationProblem, constraints
257260
export SteadyStateProblem
258261
export JumpProblem
262+
export SemilinearODEFunction, SemilinearODEProblem
259263
export alias_elimination, flatten
260264
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
261265
instream

src/problems/odeproblem.jl

Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,163 @@ end
107107
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
108108
end
109109

110+
@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
111+
sys::System; u0 = nothing, p = nothing, t = nothing,
112+
semiquadratic_form = nothing,
113+
stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false,
114+
eval_expression = false, eval_module = @__MODULE__,
115+
expression = Val{false}, sparse = false, check_compatibility = true,
116+
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
117+
analytic = nothing, kwargs...) where {iip, specialize}
118+
check_complete(sys, SemilinearODEFunction)
119+
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)
120+
121+
if semiquadratic_form === nothing
122+
semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
123+
sys = add_semiquadratic_parameters(sys, semiquadratic_form...)
124+
end
125+
126+
A, B, C = semiquadratic_form
127+
M = calculate_massmatrix(sys)
128+
_M = concrete_massmatrix(M; sparse, u0)
129+
dvs = unknowns(sys)
130+
131+
f1,
132+
f2 = generate_semiquadratic_functions(
133+
sys, A, B, C; stiff_linear, stiff_quadratic,
134+
stiff_nonlinear, expression, wrap_gfw = Val{true},
135+
eval_expression, eval_module, kwargs...)
136+
137+
if jac
138+
Cjac = (C === nothing || !stiff_nonlinear) ? nothing : Symbolics.jacobian(C, dvs)
139+
_jac = generate_semiquadratic_jacobian(
140+
sys, A, B, C, Cjac; sparse, expression,
141+
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
142+
_W_sparsity = get_semiquadratic_W_sparsity(
143+
sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
144+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
145+
else
146+
_jac = nothing
147+
W_prototype = nothing
148+
end
149+
150+
observedfun = ObservedFunctionCache(
151+
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)
152+
153+
args = (; f1)
154+
kwargs = (; jac = _jac, jac_prototype = W_prototype)
155+
f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...)
156+
157+
args = (; f1, f2)
158+
kwargs = (;
159+
sys = sys,
160+
jac = _jac,
161+
mass_matrix = _M,
162+
jac_prototype = W_prototype,
163+
observed = observedfun,
164+
analytic,
165+
initialization_data)
166+
167+
return maybe_codegen_scimlfn(
168+
expression, SplitFunction{iip, specialize}, args; kwargs...)
169+
end
170+
171+
@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
172+
sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing,
173+
expression = Val{false}, callback = nothing, sparse = false,
174+
stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false,
175+
jac = false, kwargs...) where {
176+
iip, spec}
177+
check_complete(sys, SemilinearODEProblem)
178+
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)
179+
180+
A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
181+
eqs = equations(sys)
182+
dvs = unknowns(sys)
183+
184+
sys = add_semiquadratic_parameters(sys, A, B, C)
185+
if A !== nothing
186+
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
187+
else
188+
linear_matrix_param = nothing
189+
end
190+
if B !== nothing
191+
quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i)))
192+
for i in 1:length(eqs)]
193+
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
194+
else
195+
quadratic_forms = diffcache_par = nothing
196+
end
197+
198+
op = to_varmap(op, dvs)
199+
floatT = calculate_float_type(op, typeof(op))
200+
_u0_eltype = something(u0_eltype, floatT)
201+
202+
guess = copy(guesses(sys))
203+
defs = copy(defaults(sys))
204+
if A !== nothing
205+
guess[linear_matrix_param] = fill(NaN, size(A))
206+
defs[linear_matrix_param] = A
207+
end
208+
if B !== nothing
209+
for (par, mat) in zip(quadratic_forms, B)
210+
guess[par] = fill(NaN, size(mat))
211+
defs[par] = mat
212+
end
213+
cachelen = jac ? length(dvs) * length(eqs) : length(dvs)
214+
defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
215+
end
216+
@set! sys.guesses = guess
217+
@set! sys.defaults = defs
218+
219+
f, u0,
220+
p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
221+
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
222+
semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs...)
223+
224+
kwargs = process_kwargs(sys; expression, callback, kwargs...)
225+
226+
args = (; f, u0, tspan, p)
227+
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
228+
end
229+
230+
"""
231+
$(TYPEDSIGNATURES)
232+
233+
Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
234+
`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
235+
"""
236+
function add_semiquadratic_parameters(sys::System, A, B, C)
237+
eqs = equations(sys)
238+
n = length(eqs)
239+
var_to_name = copy(get_var_to_name(sys))
240+
if B !== nothing
241+
for i in eachindex(B)
242+
B[i] === nothing && continue
243+
par = get_quadratic_form_param((n, n), i)
244+
var_to_name[get_quadratic_form_name(i)] = par
245+
sys = with_additional_constant_parameter(sys, par)
246+
end
247+
par = get_diffcache_param(Float64)
248+
var_to_name[DIFFCACHE_PARAM_NAME] = par
249+
sys = with_additional_nonnumeric_parameter(sys, par)
250+
end
251+
if A !== nothing
252+
par = get_linear_matrix_param((n, n))
253+
var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
254+
sys = with_additional_constant_parameter(sys, par)
255+
end
256+
@set! sys.var_to_name = var_to_name
257+
if get_parent(sys) !== nothing
258+
@set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C)
259+
end
260+
return sys
261+
end
262+
110263
function check_compatible_system(
111264
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
112-
Type{DAEProblem}, Type{SteadyStateProblem}},
265+
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
266+
Type{SemilinearODEProblem}},
113267
sys::System)
114268
check_time_dependent(sys, T)
115269
check_not_dde(sys)

0 commit comments

Comments
 (0)