Skip to content

Commit e3bbd8f

Browse files
authored
Merge pull request #2060 from SciML/myb/param_idx
Add more flexible add_accumulations
2 parents 0d8182f + 0134f5f commit e3bbd8f

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -455,14 +455,30 @@ $(SIGNATURES)
455455
Add accumulation variables for `vars`.
456456
"""
457457
function add_accumulations(sys::ODESystem, vars = states(sys))
458+
avars = [rename(v, Symbol(:accumulation_, getname(v))) for v in vars]
459+
add_accumulations(sys, avars .=> vars)
460+
end
461+
462+
"""
463+
$(SIGNATURES)
464+
465+
Add accumulation variables for `vars`. `vars` is a vector of pairs in the form
466+
of
467+
468+
```julia
469+
[cumulative_var1 => x + y, cumulative_var2 => x^2]
470+
```
471+
Then, cumulative variables `cumulative_var1` and `cumulative_var2` that computes
472+
the comulative `x + y` and `x^2` would be added to `sys`.
473+
"""
474+
function add_accumulations(sys::ODESystem, vars::Vector{<:Pair})
458475
eqs = get_eqs(sys)
459-
accs = filter(x -> startswith(string(x), "accumulation_"), states(sys))
460-
if !isempty(accs)
461-
error("$accs variable names start with \"accumulation_\"")
476+
avars = map(first, vars)
477+
if (ints = intersect(avars, states(sys)); !isempty(ints))
478+
error("$ints already exist in the system!")
462479
end
463-
avars = [rename(v, Symbol(:accumulation_, getname(v))) for v in vars]
464480
D = Differential(get_iv(sys))
465-
@set! sys.eqs = [eqs; Equation[D(a) ~ v for (a, v) in zip(avars, vars)]]
481+
@set! sys.eqs = [eqs; Equation[D(a) ~ v[2] for (a, v) in zip(avars, vars)]]
466482
@set! sys.states = [get_states(sys); avars]
467483
@set! sys.defaults = merge(get_defaults(sys), Dict(a => 0.0 for a in avars))
468484
end

test/odesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,14 @@ eqs = [0 ~ x + z
348348
D(accumulation_z) ~ z
349349
D(x) ~ y]
350350
@test sort(equations(asys), by = string) == eqs
351+
@variables ac(t)
352+
asys = add_accumulations(sys, [ac => (x + y)^2])
353+
eqs = [0 ~ x + z
354+
0 ~ x - y
355+
D(ac) ~ (x + y)^2
356+
D(x) ~ y]
357+
@test sort(equations(asys), by = string) == eqs
358+
351359
sys2 = ode_order_lowering(sys)
352360
M = ModelingToolkit.calculate_massmatrix(sys2)
353361
@test M == Diagonal([1, 0, 0])

0 commit comments

Comments
 (0)