Skip to content

Commit 350f661

Browse files
feat: implement DDEProblem and DDEFunction for System
1 parent 7cff503 commit 350f661

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

src/problems/ddeproblem.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
@fallback_iip_specialize function SciMLBase.DDEFunction{iip, spec}(
2+
sys::System, _d = nothing, u0 = nothing, p = nothing;
3+
eval_expression = false, eval_module = @__MODULE__, checkbounds = false,
4+
initialization_data = nothing, cse = true, check_compatibility = true,
5+
sparse = false, simplify = false, analytic = nothing, kwargs...) where {
6+
iip, spec}
7+
check_complete(sys, DDEFunction)
8+
check_compatibility && check_compatible_system(DDEFunction, sys)
9+
10+
dvs = unknowns(sys)
11+
ps = parameters(sys)
12+
13+
f = generate_rhs(sys, dvs, ps; expression = Val{false},
14+
eval_expression, eval_module, checkbounds = checkbounds, cse,
15+
kwargs...)
16+
17+
if spec === SciMLBase.FunctionWrapperSpecialize && iip
18+
if u0 === nothing || p === nothing || t === nothing
19+
error("u0, p, and t must be specified for FunctionWrapperSpecialize on DDEFunction.")
20+
end
21+
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
22+
end
23+
24+
M = calculate_massmatrix(sys)
25+
_M = concrete_massmatrix(M; sparse, u0)
26+
27+
observedfun = ObservedFunctionCache(
28+
sys; eval_expression, eval_module, checkbounds, cse)
29+
30+
DDEFunction{iip, spec}(f;
31+
sys = sys,
32+
mass_matrix = _M,
33+
observed = observedfun,
34+
analytic = analytic,
35+
initialization_data)
36+
end
37+
38+
@fallback_iip_specialize function SciMLBase.DDEProblem{iip, spec}(
39+
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
40+
callback = nothing, check_length = true, cse = true, checkbounds = false,
41+
eval_expression = false, eval_module = @__MODULE__, check_compatibility = true,
42+
u0_constructor = identity,
43+
kwargs...) where {iip, spec}
44+
check_complete(sys, DDEProblem)
45+
check_compatibility && check_compatible_system(DDEProblem, sys)
46+
47+
f, u0, p = process_SciMLProblem(DDEFunction{iip, spec}, sys, u0map, parammap;
48+
t = tspan !== nothing ? tspan[1] : tspan, check_length, cse, checkbounds,
49+
eval_expression, eval_module, check_compatibility, symbolic_u0 = true, kwargs...)
50+
51+
h = generate_history(
52+
sys, u0; expression = Val{false}, cse, eval_expression, eval_module,
53+
checkbounds)
54+
u0 = float.(h(p, tspan[1]))
55+
if u0 !== nothing
56+
u0 = u0_constructor(u0)
57+
end
58+
59+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
60+
61+
# Call `remake` so it runs initialization if it is trivial
62+
return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs...))
63+
end
64+
65+
function check_compatible_system(T::Union{Type{DDEFunction}, Type{DDEProblem}}, sys::System)
66+
check_time_dependent(sys, T)
67+
check_is_dde(sys)
68+
check_no_cost(sys, T)
69+
check_no_constraints(sys, T)
70+
check_no_jumps(sys, T)
71+
check_no_noise(sys, T)
72+
end

src/systems/codegen.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,20 @@ function generate_dae_jacobian(sys::System, dvs = unknowns(sys),
222222
return GeneratedFunctionWrapper{(3, 5, is_split(sys))}(f_oop, f_iip)
223223
end
224224

225+
function generate_history(sys::System, u0; expression = Val{true},
226+
eval_expression = false, eval_module = @__MODULE__, kwargs...)
227+
p = reorder_parameters(sys)
228+
res = build_function_wrapper(sys, u0, p..., get_iv(sys); expression = Val{true},
229+
expression_module = eval_module, p_start = 1, p_end = length(p),
230+
similarto = typeof(u0), wrap_delays = false, kwargs...)
231+
232+
if expression == Val{true}
233+
return res
234+
end
235+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
236+
return GeneratedFunctionWrapper{(1, 2, is_split(sys))}(f_oop, f_iip)
237+
end
238+
225239
function calculate_massmatrix(sys::System; simplify = false)
226240
eqs = [eq for eq in equations(sys)]
227241
M = zeros(length(eqs), length(eqs))

0 commit comments

Comments
 (0)