Skip to content

Commit 21b2d96

Browse files
author
dd
committed
using integrator interface
1 parent 8e7ef41 commit 21b2d96

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

src/systems/callbacks.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ function FunctionalAffect(f, sts, pars, ctx = nothing)
2323
# sts & pars contain either pairs: resistor.R => R, or Syms: R
2424
vs = [x isa Pair ? x.first : x for x in sts]
2525
vs_syms = [x isa Pair ? Symbol(x.second) : getname(x) for x in sts]
26-
length(vs_syms) == length(unique(vs_syms)) || error("Variables are not unique.")
2726

2827
ps = [x isa Pair ? x.first : x for x in pars]
2928
ps_syms = [x isa Pair ? Symbol(x.second) : getname(x) for x in pars]
30-
length(ps_syms) == length(unique(ps_syms)) || error("Parameters are not unique.")
31-
29+
length(vs_syms) + length(ps_syms) == length(unique(vcat(vs_syms, ps_syms))) || error("All symbols for variables & parameters must be unique.")
30+
3231
FunctionalAffect(f, vs, vs_syms, ps, ps_syms, ctx)
3332
end
3433

@@ -405,20 +404,16 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
405404
# (MTK should keep these symbols)
406405
v = filter(x -> !isnothing(x[1]), collect(zip(v_inds, states_syms(affect))))
407406
v_inds = [x[1] for x in v]
408-
v_syms = Tuple([x[2] for x in v])
407+
v_syms = [x[2] for x in v]
409408
p = filter(x -> !isnothing(x[1]), collect(zip(p_inds, parameters_syms(affect))))
410409
p_inds = [x[1] for x in p]
411-
p_syms = Tuple([x[2] for x in p])
412-
413-
let v_inds=v_inds, p_inds=p_inds, v_syms=v_syms, p_syms=p_syms, user_affect=func(affect), ctx = context(affect)
414-
function (integ)
415-
uv = @views integ.u[v_inds]
416-
pv = @views integ.p[p_inds]
410+
p_syms = [x[2] for x in p]
417411

418-
u = LArray{v_syms}(uv)
419-
p = LArray{p_syms}(pv)
412+
kwargs = zip(vcat(v_syms, p_syms), vcat(v_inds, p_inds))
420413

421-
user_affect(integ.t, u, p, ctx)
414+
let kwargs=kwargs, user_affect=func(affect), ctx = context(affect)
415+
function (integ)
416+
user_affect(integ, ctx; kwargs...)
422417
end
423418
end
424419
end

test/funcaffect.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using ModelingToolkit, Test, DifferentialEquations
2+
3+
@parameters t a b
4+
@variables u(t)
5+
D = Differential(t)
6+
7+
eqs = [ D(u) ~ -u ]
8+
9+
affect1!(integ, ctx; u) = integ.u[u] += 10
10+
11+
@named sys = ODESystem(eqs, t, [u], [], discrete_events=[[4.0, ]=>(affect1!, [u], [], nothing)])
12+
prob = ODEProblem(sys, [u=> 10.0], (0, 10.0))
13+
sol = solve(prob, Tsit5())
14+
i4 = findfirst(==(4.0), sol[:t])
15+
@test sol.u[i4+1][1] > 10.0
16+
17+
# context
18+
function affect2!(integ, ctx; u)
19+
integ.u[u] += ctx[1]
20+
ctx[1] *= 2
21+
end
22+
ctx1 = [10.0, ]
23+
@named sys = ODESystem(eqs, t, [u], [], discrete_events=[[4.0, 8.0]=>(affect2!, [u], [], ctx1)])
24+
prob = ODEProblem(sys, [u=> 10.0], (0, 10.0))
25+
sol = solve(prob, Tsit5())
26+
i4 = findfirst(==(4.0), sol[:t])
27+
@test sol.u[i4+1][1] > 10.0
28+
i8 = findfirst(==(8.0), sol[:t])
29+
@test sol.u[i8+1][1] > 20.0
30+
@test ctx1[1] == 40.0
31+
32+
33+

0 commit comments

Comments
 (0)