Skip to content

Commit 62478f8

Browse files
committed
Fix #1488
1 parent dd6bbd3 commit 62478f8

File tree

3 files changed

+49
-11
lines changed

3 files changed

+49
-11
lines changed

src/structural_transformation/codegen.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -393,20 +393,40 @@ function build_observed_function(
393393

394394
required_algvars = Set(intersect(algvars, vars))
395395
obs = observed(sys)
396-
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
397-
# FIXME: this is a rather rough estimate of dependencies.
398-
maxidx = 0
396+
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
397+
namespaced_to_obs = Dict(states(sys, x.lhs) => x.lhs for x in obs)
398+
namespaced_to_sts = Dict(states(sys, x) => x for x in states(sys))
399399
sts = Set(states(sys))
400+
401+
# FIXME: This is a rather rough estimate of dependencies. We assume
402+
# the expression depends on everything before the `maxidx`.
403+
subs = Dict()
404+
maxidx = 0
400405
for (i, s) in enumerate(dep_vars)
401406
idx = get(observed_idx, s, nothing)
402-
if idx === nothing
403-
if !(s in sts)
407+
if idx !== nothing
408+
idx > maxidx && (maxidx = idx)
409+
else
410+
s′ = get(namespaced_to_obs, s, nothing)
411+
if s′ !== nothing
412+
subs[s] = s′
413+
s = s′
414+
idx = get(observed_idx, s, nothing)
415+
end
416+
if idx !== nothing
417+
idx > maxidx && (maxidx = idx)
418+
elseif !(s in sts)
419+
s′ = get(namespaced_to_sts, s, nothing)
420+
if s′ !== nothing
421+
subs[s] = s′
422+
continue
423+
end
404424
throw(ArgumentError("$s is either an observed nor a state variable."))
405425
end
406426
continue
407427
end
408-
idx > maxidx && (maxidx = idx)
409428
end
429+
ts = map(t->substitute(t, subs), ts)
410430
vs = Set()
411431
for idx in 1:maxidx
412432
vars!(vs, obs[idx].rhs)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
357357
end
358358

359359
M = calculate_massmatrix(sys)
360-
360+
361361
_M = if sparse && !(u0 === nothing || M === I)
362362
SparseArrays.sparse(M)
363363
elseif u0 === nothing || M === I

src/systems/diffeqs/odesystem.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,21 +279,39 @@ function build_explicit_observed_function(
279279

280280
obs = observed(sys)
281281
sts = Set(states(sys))
282-
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
282+
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
283+
namespaced_to_obs = Dict(states(sys, x.lhs) => x.lhs for x in obs)
284+
namespaced_to_sts = Dict(states(sys, x) => x for x in states(sys))
283285

284286
# FIXME: This is a rather rough estimate of dependencies. We assume
285287
# the expression depends on everything before the `maxidx`.
288+
subs = Dict()
286289
maxidx = 0
287290
for (i, s) in enumerate(dep_vars)
288291
idx = get(observed_idx, s, nothing)
289-
if idx === nothing
290-
if !(s in sts)
292+
if idx !== nothing
293+
idx > maxidx && (maxidx = idx)
294+
else
295+
s′ = get(namespaced_to_obs, s, nothing)
296+
if s′ !== nothing
297+
subs[s] = s′
298+
s = s′
299+
idx = get(observed_idx, s, nothing)
300+
end
301+
if idx !== nothing
302+
idx > maxidx && (maxidx = idx)
303+
elseif !(s in sts)
304+
s′ = get(namespaced_to_sts, s, nothing)
305+
if s′ !== nothing
306+
subs[s] = s′
307+
continue
308+
end
291309
throw(ArgumentError("$s is either an observed nor a state variable."))
292310
end
293311
continue
294312
end
295-
idx > maxidx && (maxidx = idx)
296313
end
314+
ts = map(t->substitute(t, subs), ts)
297315
obsexprs = map(eq -> eq.lhseq.rhs, obs[1:maxidx])
298316

299317
dvs = DestructuredArgs(states(sys), inbounds=!checkbounds)

0 commit comments

Comments
 (0)