Skip to content

Commit 1f1b48f

Browse files
committed
Merge branch 'master' into myb/connector_type
2 parents e953585 + 00590bf commit 1f1b48f

File tree

10 files changed

+70
-28
lines changed

10 files changed

+70
-28
lines changed

src/structural_transformation/codegen.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,20 @@ function build_torn_function(sys;
318318
sol_states = sol_states,
319319
var2assignment = var2assignment
320320

321-
function generated_observed(obsvar, u, p, t)
321+
function generated_observed(obsvar, args...)
322322
obs = get!(dict, value(obsvar)) do
323323
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
324324
is_solver_state_idxs, assignments, deps,
325325
sol_states, var2assignment,
326326
checkbounds = checkbounds)
327327
end
328-
obs(u, p, t)
328+
if args === ()
329+
let obs = obs
330+
(u, p, t) -> obs(u, p, t)
331+
end
332+
else
333+
obs(args...)
334+
end
329335
end
330336
end
331337

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ for prop in [:eqs
182182
:iv
183183
:states
184184
:ps
185+
:tspan
185186
:var_to_name
186187
:ctrls
187188
:defaults

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,32 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
328328
obs = observed(sys)
329329
observedfun = if steady_state
330330
let sys = sys, dict = Dict()
331-
function generated_observed(obsvar, u, p, t = Inf)
331+
function generated_observed(obsvar, args...)
332332
obs = get!(dict, value(obsvar)) do
333333
build_explicit_observed_function(sys, obsvar)
334334
end
335-
obs(u, p, t)
335+
if args === ()
336+
let obs = obs
337+
(u, p, t = Inf) -> obs(u, p, t)
338+
end
339+
else
340+
length(args) == 2 ? obs(args..., Inf) : obs(args...)
341+
end
336342
end
337343
end
338344
else
339345
let sys = sys, dict = Dict()
340-
function generated_observed(obsvar, u, p, t)
346+
function generated_observed(obsvar, args...)
341347
obs = get!(dict, value(obsvar)) do
342348
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
343349
end
344-
obs(u, p, t)
350+
if args === ()
351+
let obs = obs
352+
(u, p, t) -> obs(u, p, t)
353+
end
354+
else
355+
obs(args...)
356+
end
345357
end
346358
end
347359
end
@@ -424,11 +436,17 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
424436

425437
obs = observed(sys)
426438
observedfun = let sys = sys, dict = Dict()
427-
function generated_observed(obsvar, u, p, t)
439+
function generated_observed(obsvar, args...)
428440
obs = get!(dict, value(obsvar)) do
429441
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
430442
end
431-
obs(u, p, t)
443+
if args === ()
444+
let obs = obs
445+
(u, p, t) -> obs(u, p, t)
446+
end
447+
else
448+
obs(args...)
449+
end
432450
end
433451
end
434452

@@ -662,7 +680,8 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
662680
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
663681
end
664682

665-
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map, tspan,
683+
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
684+
tspan = get_tspan(sys),
666685
parammap = DiffEqBase.NullParameters();
667686
callback = nothing,
668687
check_length = true,

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x),
1919
D(y) ~ x*(ρ-z)-y,
2020
D(z) ~ x*y - β*z]
2121
22-
@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β])
22+
@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],tspan=(0, 1000.0))
2323
```
2424
"""
2525
struct ODESystem <: AbstractODESystem
@@ -41,6 +41,8 @@ struct ODESystem <: AbstractODESystem
4141
states::Vector
4242
"""Parameter variables. Must not contain the independent variable."""
4343
ps::Vector
44+
"""Time span."""
45+
tspan::Union{NTuple{2, Any}, Nothing}
4446
"""Array variables."""
4547
var_to_name::Any
4648
"""Control parameters (some subset of `ps`)."""
@@ -125,7 +127,7 @@ struct ODESystem <: AbstractODESystem
125127
"""
126128
complete::Bool
127129

128-
function ODESystem(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
130+
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
129131
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
130132
torn_matching, connector_type, preface, cevents,
131133
devents, metadata = nothing, tearing_state = nothing,
@@ -140,7 +142,7 @@ struct ODESystem <: AbstractODESystem
140142
if checks == true || (checks & CheckUnits) > 0
141143
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
142144
end
143-
new(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
145+
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
144146
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
145147
connector_type, preface, cevents, devents, metadata, tearing_state,
146148
substitutions, complete)
@@ -151,6 +153,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
151153
controls = Num[],
152154
observed = Equation[],
153155
systems = ODESystem[],
156+
tspan = nothing,
154157
name = nothing,
155158
default_u0 = Dict(),
156159
default_p = Dict(),
@@ -195,7 +198,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
195198
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
196199
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
197200
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
198-
deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
201+
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
199202
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
200203
connector_type, preface, cont_callbacks, disc_callbacks,
201204
metadata, checks = checks)

src/systems/diffeqs/sdesystem.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ noiseeqs = [0.1*x,
2323
0.1*y,
2424
0.1*z]
2525
26-
@named de = SDESystem(eqs,noiseeqs,t,[x,y,z],[σ,ρ,β])
26+
@named de = SDESystem(eqs,noiseeqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0))
2727
```
2828
"""
2929
struct SDESystem <: AbstractODESystem
@@ -42,6 +42,8 @@ struct SDESystem <: AbstractODESystem
4242
states::Vector
4343
"""Parameter variables. Must not contain the independent variable."""
4444
ps::Vector
45+
"""Time span."""
46+
tspan::Union{NTuple{2, Any}, Nothing}
4547
"""Array variables."""
4648
var_to_name::Any
4749
"""Control parameters (some subset of `ps`)."""
@@ -110,7 +112,8 @@ struct SDESystem <: AbstractODESystem
110112
"""
111113
complete::Bool
112114

113-
function SDESystem(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
115+
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
116+
tgrad,
114117
jac,
115118
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
116119
cevents, devents, metadata = nothing, complete = false;
@@ -124,7 +127,7 @@ struct SDESystem <: AbstractODESystem
124127
if checks == true || (checks & CheckUnits) > 0
125128
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
126129
end
127-
new(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
130+
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
128131
ctrl_jac,
129132
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
130133
metadata, complete)
@@ -135,6 +138,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
135138
controls = Num[],
136139
observed = Num[],
137140
systems = SDESystem[],
141+
tspan = nothing,
138142
default_u0 = Dict(),
139143
default_p = Dict(),
140144
defaults = _merge(Dict(default_u0), Dict(default_p)),
@@ -177,7 +181,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
177181
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
178182

179183
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
180-
deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
184+
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
181185
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
182186
cont_callbacks, disc_callbacks, metadata; checks = checks)
183187
end
@@ -531,7 +535,7 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
531535
SDEFunctionExpr{true}(sys, args...; kwargs...)
532536
end
533537

534-
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
538+
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys),
535539
parammap = DiffEqBase.NullParameters();
536540
sparsenoise = nothing, check_length = true,
537541
callback = nothing, kwargs...) where {iip}

src/systems/discrete_system/discrete_system.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x),
1919
D(y) ~ x*(ρ-z)-y,
2020
D(z) ~ x*y - β*z]
2121
22-
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]) # or
22+
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or
2323
@named de = DiscreteSystem(eqs)
2424
```
2525
"""
@@ -37,6 +37,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
3737
states::Vector
3838
"""Parameter variables. Must not contain the independent variable."""
3939
ps::Vector
40+
"""Time span."""
41+
tspan::Union{NTuple{2, Any}, Nothing}
4042
"""Array variables."""
4143
var_to_name::Any
4244
"""Control parameters (some subset of `ps`)."""
@@ -81,7 +83,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
8183
"""
8284
complete::Bool
8385

84-
function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed,
86+
function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls,
87+
observed,
8588
name,
8689
systems, defaults, preface, connector_type,
8790
metadata = nothing,
@@ -94,7 +97,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
9497
if checks == true || (checks & CheckUnits) > 0
9598
all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs)
9699
end
97-
new(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems,
100+
new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, name,
101+
systems,
98102
defaults,
99103
preface, connector_type, metadata, tearing_state, substitutions, complete)
100104
end
@@ -109,6 +113,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
109113
controls = Num[],
110114
observed = Num[],
111115
systems = DiscreteSystem[],
116+
tspan = nothing,
112117
name = nothing,
113118
default_u0 = Dict(),
114119
default_p = Dict(),
@@ -142,7 +147,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
142147
throw(ArgumentError("System names must be unique."))
143148
end
144149
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
145-
eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems,
150+
eqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, name, systems,
146151
defaults, preface, connector_type, metadata, kwargs...)
147152
end
148153

@@ -192,7 +197,7 @@ end
192197
193198
Generates an DiscreteProblem from an DiscreteSystem.
194199
"""
195-
function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map, tspan,
200+
function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map = [], tspan = get_tspan(sys),
196201
parammap = SciMLBase.NullParameters();
197202
eval_module = @__MODULE__,
198203
eval_expression = true,

test/components.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ end
2121

2222
function check_rc_sol(sol)
2323
rpi = sol[rc_model.resistor.p.i]
24+
rpifun = sol.prob.f.observed(rc_model.resistor.p.i)
25+
@test rpifun.(sol.u, (sol.prob.p,), sol.t) == rpi
2426
@test any(!isequal(rpi[1]), rpi) # test that we don't have a constant system
2527
@test sol[rc_model.resistor.p.i] == sol[resistor.p.i] == sol[capacitor.p.i]
2628
@test sol[rc_model.resistor.n.i] == sol[resistor.n.i] == -sol[capacitor.p.i]

test/discretesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ eqs2 = [D(S) ~ S - infection2,
5252
D(I) ~ I + infection2 - recovery2,
5353
D(R) ~ R + recovery2]
5454

55-
@named sys = DiscreteSystem(eqs2; controls = [β, γ])
55+
@named sys = DiscreteSystem(eqs2; controls = [β, γ], tspan)
5656
@test ModelingToolkit.defaults(sys) != Dict()
5757

58-
prob_map2 = DiscreteProblem(sys, [], tspan)
58+
prob_map2 = DiscreteProblem(sys)
5959
sol_map2 = solve(prob_map, FunctionMap());
6060

6161
@test sol_map.u == sol_map2.u

test/odesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,11 @@ eqs = [D(D(x)) ~ -b / M * D(x) - k / M * x]
374374
ps = [M, b, k]
375375
default_u0 = [D(x) => 0.0, x => 10.0]
376376
default_p = [M => 1.0, b => 1.0, k => 1.0]
377-
@named sys = ODESystem(eqs, t, [x], ps, defaults = [default_u0; default_p])
377+
@named sys = ODESystem(eqs, t, [x], ps; defaults = [default_u0; default_p], tspan)
378378
sys = ode_order_lowering(sys)
379-
prob = ODEProblem(sys, [], tspan)
379+
prob = ODEProblem(sys)
380380
sol = solve(prob, Tsit5())
381+
@test sol.t[end] == tspan[end]
381382
@test sum(abs, sol[end]) < 1
382383

383384
# check_eqs_u0 kwarg test

test/sdesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ noiseeqs = [0.1 * x,
2020
@named sys = ODESystem(eqs, t, [x, y, z], [σ, ρ, β])
2121
@test SDESystem(sys, noiseeqs, name = :foo) isa SDESystem
2222

23-
@named de = SDESystem(eqs, noiseeqs, t, [x, y, z], [σ, ρ, β])
23+
@named de = SDESystem(eqs, noiseeqs, t, [x, y, z], [σ, ρ, β], tspan = (0.0, 10.0))
2424
f = eval(generate_diffusion_function(de)[1])
2525
@test f(ones(3), rand(3), nothing) == 0.1ones(3)
2626

@@ -36,6 +36,7 @@ solexpr = solve(eval(probexpr), SRIW1(), seed = 1)
3636

3737
# Test no error
3838
@test_nowarn SDEProblem(de, nothing, (0, 10.0))
39+
@test SDEProblem(de, nothing).tspan == (0.0, 10.0)
3940

4041
noiseeqs_nd = [0.01*x 0.01*x*y 0.02*x*z
4142
σ 0.01*y 0.02*x*z

0 commit comments

Comments
 (0)