Skip to content

Commit 5b62aee

Browse files
fix: improve validation of variables in build_explicit_observed_function
1 parent f5176c6 commit 5b62aee

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,15 +476,15 @@ function build_explicit_observed_function(sys, ts;
476476
end
477477
end
478478
allsyms = Set(all_symbols(sys))
479+
iv = has_iv(sys) ? get_iv(sys) : nothing
479480
for var in vs
480481
var = unwrap(var)
481482
newvar = get(ns_map, var, nothing)
482483
if newvar !== nothing
483484
namespace_subs[var] = newvar
484485
var = newvar
485486
end
486-
if throw && !(var in allsyms) &&
487-
(!iscall(var) || operation(var) !== getindex || !(arguments(var)[1] in allsyms))
487+
if throw && !var_in_varlist(var, allsyms, iv)
488488
Base.throw(ArgumentError("Symbol $var is not present in the system."))
489489
end
490490
end

src/utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,3 +1267,22 @@ function symbol_to_symbolic(sys::AbstractSystem, sym; allsyms = all_symbols(sys)
12671267
end
12681268
return sym
12691269
end
1270+
1271+
"""
1272+
$(TYPEDSIGNATURES)
1273+
1274+
Check if `var` is present in `varlist`. `iv` is the independent variable of the system,
1275+
and should be `nothing` if not applicable.
1276+
"""
1277+
function var_in_varlist(var, varlist::AbstractSet, iv)
1278+
var = unwrap(var)
1279+
# simple case
1280+
return var in varlist ||
1281+
# indexed array symbolic, unscalarized array present
1282+
(iscall(var) && operation(var) === getindex && arguments(var)[1] in varlist) ||
1283+
# unscalarized sized array symbolic, all scalarized elements present
1284+
(symbolic_type(var) == ArraySymbolic() && is_sized_array_symbolic(var) &&
1285+
all(x -> x in varlist, collect(var))) ||
1286+
# delayed variables
1287+
(isdelay(var, iv) && var_in_varlist(operation(var)(iv), varlist, iv))
1288+
end

0 commit comments

Comments
 (0)