Skip to content

Commit f70aec9

Browse files
author
dd
committed
pass variables/states as named-tuples into affect
1 parent 6f8ee7e commit f70aec9

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

src/systems/callbacks.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ function FunctionalAffect(f, sts, pars, ctx = nothing)
2323
# sts & pars contain either pairs: resistor.R => R, or Syms: R
2424
vs = [x isa Pair ? x.first : x for x in sts]
2525
vs_syms = [x isa Pair ? Symbol(x.second) : getname(x) for x in sts]
26+
length(vs_syms) == length(unique(vs_syms)) || error("Variables are not unique")
2627

2728
ps = [x isa Pair ? x.first : x for x in pars]
2829
ps_syms = [x isa Pair ? Symbol(x.second) : getname(x) for x in pars]
29-
length(vs_syms) + length(ps_syms) == length(unique(vcat(vs_syms, ps_syms))) || error("All symbols for variables & parameters must be unique.")
30+
length(ps_syms) == length(unique(ps_syms)) || error("Parameters are not unique")
3031

3132
FunctionalAffect(f, vs, vs_syms, ps, ps_syms, ctx)
3233
end
@@ -146,10 +147,10 @@ end
146147

147148
struct SymbolicDiscreteCallback
148149
# condition can be one of:
149-
# TODO: Iterative
150150
# Δt::Real - Periodic with period Δt
151151
# Δts::Vector{Real} - events trigger in this times (Preset)
152152
# condition::Vector{Equation} - event triggered when condition is true
153+
# TODO: Iterative
153154
condition
154155
affects
155156

@@ -411,18 +412,14 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
411412

412413
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
413414
# (MTK should keep these symbols)
414-
v = filter(x -> !isnothing(x[1]), collect(zip(v_inds, states_syms(affect))))
415-
v_inds = [x[1] for x in v]
416-
v_syms = [x[2] for x in v]
417-
p = filter(x -> !isnothing(x[1]), collect(zip(p_inds, parameters_syms(affect))))
418-
p_inds = [x[1] for x in p]
419-
p_syms = [x[2] for x in p]
420-
421-
kwargs = zip(vcat(v_syms, p_syms), vcat(v_inds, p_inds))
415+
u = filter(x -> !isnothing(x[2]), collect(zip(states_syms(affect), v_inds)))
416+
p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds)))
417+
u = NamedTuple(u)
418+
p = NamedTuple(p)
422419

423-
let kwargs=kwargs, user_affect=func(affect), ctx = context(affect)
420+
let u=u, p=p, user_affect=func(affect), ctx = context(affect)
424421
function (integ)
425-
user_affect(integ, ctx; kwargs...)
422+
user_affect(integ, u, p, ctx)
426423
end
427424
end
428425
end

test/funcaffect.jl

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ D = Differential(t)
66

77
eqs = [ D(u) ~ -u ]
88

9-
affect1!(integ, ctx; u) = integ.u[u] += 10
9+
affect1!(integ, u, p, ctx) = integ.u[u.u] += 10
1010

1111
@named sys = ODESystem(eqs, t, [u], [], discrete_events=[[4.0, ]=>(affect1!, [u], [], nothing)])
1212
prob = ODEProblem(sys, [u=> 10.0], (0, 10.0))
@@ -15,8 +15,8 @@ i4 = findfirst(==(4.0), sol[:t])
1515
@test sol.u[i4+1][1] > 10.0
1616

1717
# context
18-
function affect2!(integ, ctx; u)
19-
integ.u[u] += ctx[1]
18+
function affect2!(integ, u, p, ctx)
19+
integ.u[u.u] += ctx[1]
2020
ctx[1] *= 2
2121
end
2222
ctx1 = [10.0, ]
@@ -30,9 +30,9 @@ i8 = findfirst(==(8.0), sol[:t])
3030
@test ctx1[1] == 40.0
3131

3232
# parameter
33-
function affect3!(integ, ctx; u, a)
34-
integ.u[u] += integ.p[a]
35-
integ.p[a] *= 2
33+
function affect3!(integ, u, p, ctx)
34+
integ.u[u.u] += integ.p[p.a]
35+
integ.p[p.a] *= 2
3636
end
3737

3838
@parameters a = 10.0
@@ -46,9 +46,9 @@ i8 = findfirst(==(8.0), sol[:t])
4646
@test sol.u[i8+1][1] > 20.0
4747

4848
# rename parameter
49-
function affect3!(integ, ctx; u, b)
50-
integ.u[u] += integ.p[b]
51-
integ.p[b] *= 2
49+
function affect3!(integ, u, p, ctx)
50+
integ.u[u.u] += integ.p[p.b]
51+
integ.p[p.b] *= 2
5252
end
5353

5454
@named sys = ODESystem(eqs, t, [u], [a], discrete_events=[[4.0, 8.0]=>(affect3!, [u], [a=> :b], nothing)])
@@ -61,8 +61,21 @@ i8 = findfirst(==(8.0), sol[:t])
6161
@test sol.u[i8+1][1] > 20.0
6262

6363
# same name
64-
@test_throws ErrorException ODESystem(eqs, t, [u], [a], discrete_events=[[4.0, 8.0]=>(affect3!, [u], [a=> :u], nothing)]; name=:sys)
64+
@variables v(t)
65+
@test_throws ErrorException ODESystem(eqs, t, [u], [a], discrete_events=[[4.0, 8.0]=>(affect3!, [u, v => :u], [a], nothing)]; name=:sys)
6566

67+
@named resistor = ODESystem(D(v) ~ v, t, [v], [])
6668

69+
# nested namespace
70+
ctx = [0]
71+
function affect4!(integ, u, p, ctx)
72+
ctx[1] += 1
73+
@test u.resistor₊v == 1
74+
end
75+
s1 = compose(ODESystem(Equation[], t, [], [], name=:s1, discrete_events=1.0=>(affect4!, [resistor.v], [], ctx)), resistor)
76+
s2 = structural_simplify(s1)
77+
prob = ODEProblem(s2, [resistor.v=> 10.0], (0, 2.01))
78+
sol = solve(prob, Tsit5())
79+
@test ctx[1] == 2
6780

6881

0 commit comments

Comments
 (0)