Skip to content

Commit c2f0d67

Browse files
Merge pull request #2549 from AayushSabharwal/as/wrap-array-vars-fix
fix: do not filter array unknowns in wrap_array_vars
2 parents a66f2b8 + ae65529 commit c2f0d67

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

src/systems/abstractsystem.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,10 @@ end
198198

199199
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
200200
isscalar = !(exprs isa AbstractArray)
201-
allvars = if isscalar
202-
Set(get_variables(exprs))
203-
else
204-
union(get_variables.(exprs)...)
205-
end
206201
array_vars = Dict{Any, AbstractArray{Int}}()
207202
for (j, x) in enumerate(dvs)
208203
if istree(x) && operation(x) == getindex
209204
arg = arguments(x)[1]
210-
any(isequal(arg), allvars) || continue
211205
inds = get!(() -> Int[], array_vars, arg)
212206
push!(inds, j)
213207
end

src/systems/diffeqs/odesystem.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,10 @@ function build_explicit_observed_function(sys, ts;
380380
ps = full_parameters(sys),
381381
op = Operator,
382382
throw = true)
383-
if (isscalar = !(ts isa AbstractVector))
383+
if (isscalar = symbolic_type(ts) !== NotSymbolic())
384384
ts = [ts]
385385
end
386-
ts = unwrap.(Symbolics.scalarize(ts))
386+
ts = unwrap.(ts)
387387

388388
vars = Set()
389389
foreach(v -> vars!(vars, v; op), ts)
@@ -399,9 +399,17 @@ function build_explicit_observed_function(sys, ts;
399399
end
400400

401401
sts = Set(unknowns(sys))
402+
sts = union(sts,
403+
Set(arguments(st)[1] for st in sts if istree(st) && operation(st) === getindex))
404+
402405
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
403406
param_set = Set(parameters(sys))
407+
param_set = union(param_set,
408+
Set(arguments(p)[1] for p in param_set if istree(p) && operation(p) === getindex))
404409
param_set_ns = Set(unknowns(sys, p) for p in parameters(sys))
410+
param_set_ns = union(param_set_ns,
411+
Set(arguments(p)[1]
412+
for p in param_set_ns if istree(p) && operation(p) === getindex))
405413
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
406414
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
407415

@@ -470,9 +478,9 @@ function build_explicit_observed_function(sys, ts;
470478
pre = get_postprocess_fbody(sys)
471479

472480
ex = Func(args, [],
473-
pre(Let(obsexprs,
474-
isscalar ? ts[1] : MakeArray(ts, output_type),
475-
false))) |> toexpr
481+
pre(Let(obsexprs,
482+
isscalar ? ts[1] : MakeArray(ts, output_type),
483+
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
476484
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
477485
end
478486

test/odesystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,9 @@ eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)]
543543
prob = ODEProblem(
544544
outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)])
545545
@test_nowarn solve(prob, Tsit5())
546+
obsfn = ModelingToolkit.build_explicit_observed_function(
547+
outersys, bar(3outersys.sys.ms, 3outersys.sys.p))
548+
@test_nowarn obsfn(sol.u[1], prob.p..., sol.t[1])
546549

547550
# x/x
548551
@variables x(t)

0 commit comments

Comments
 (0)