Skip to content

Commit 0774860

Browse files
feat: use build_function_wrapper in build_explicit_observed_function
1 parent 07f8206 commit 0774860

File tree

1 file changed

+62
-183
lines changed

1 file changed

+62
-183
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 62 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -429,215 +429,94 @@ function build_explicit_observed_function(sys, ts;
429429
param_only = false,
430430
op = Operator,
431431
throw = true,
432-
mkarray = MakeArray)
432+
mkarray = nothing)
433433
is_tuple = ts isa Tuple
434434
if is_tuple
435435
ts = collect(ts)
436+
output_type = Tuple
436437
end
437-
if (isscalar = symbolic_type(ts) !== NotSymbolic())
438-
ts = [ts]
439-
end
440-
ts = unwrap.(ts)
441-
issplit = has_index_cache(sys) && get_index_cache(sys) !== nothing
442-
if is_dde(sys)
443-
if issplit
444-
ts = map(
445-
x -> delay_to_function(
446-
sys, x; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG),
447-
ts)
448-
else
449-
ts = map(x -> delay_to_function(sys, x), ts)
450-
end
451-
end
452-
453-
vars = Set()
454-
foreach(v -> vars!(vars, v; op), ts)
455-
ivs = independent_variables(sys)
456-
dep_vars = scalarize(setdiff(vars, ivs))
457438

458-
obs = observed(sys)
459-
if param_only
460-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
461-
obs = filter(obs) do eq
462-
!(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
463-
end
464-
else
465-
obs = Equation[]
439+
allsyms = all_symbols(sys)
440+
function symbol_to_symbolic(sym)
441+
sym isa Symbol || return sym
442+
idx = findfirst(x -> (hasname(x) ? getname(x) : Symbol(x)) == sym, allsyms)
443+
idx === nothing && return sym
444+
sym = allsyms[idx]
445+
if iscall(sym) && operation(sym) == getindex
446+
sym = arguments(sym)[1]
466447
end
448+
return sym
467449
end
468-
469-
cs = collect_constants(obs)
470-
if !isempty(cs) > 0
471-
cmap = map(x -> x => getdefault(x), cs)
472-
obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs)
450+
if symbolic_type(ts) == NotSymbolic() && ts isa AbstractArray
451+
ts = map(symbol_to_symbolic, ts)
452+
else
453+
ts = symbol_to_symbolic(ts)
473454
end
474455

475-
sts = param_only ? Set() : Set(unknowns(sys))
476-
sts = param_only ? Set() :
477-
union(sts,
478-
Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex))
479-
480-
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
481-
param_set = Set(full_parameters(sys))
482-
param_set = union(param_set,
483-
Set(arguments(p)[1] for p in param_set if iscall(p) && operation(p) === getindex))
484-
param_set_ns = Set(unknowns(sys, p) for p in full_parameters(sys))
485-
param_set_ns = union(param_set_ns,
486-
Set(arguments(p)[1]
487-
for p in param_set_ns if iscall(p) && operation(p) === getindex))
488-
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
489-
namespaced_to_sts = param_only ? Dict() :
490-
Dict(unknowns(sys, x) => x for x in unknowns(sys))
491-
492-
# FIXME: This is a rather rough estimate of dependencies. We assume
493-
# the expression depends on everything before the `maxidx`.
494-
subs = Dict()
495-
maxidx = 0
496-
for s in dep_vars
497-
if s in param_set || s in param_set_ns ||
498-
iscall(s) &&
499-
operation(s) === getindex &&
500-
(arguments(s)[1] in param_set || arguments(s)[1] in param_set_ns)
501-
continue
456+
vs = ModelingToolkit.vars(ts; op)
457+
namespace_subs = Dict()
458+
ns_map = Dict{Any, Any}(renamespace(sys, eq.lhs) => eq.lhs for eq in observed(sys))
459+
for sym in unknowns(sys)
460+
ns_map[renamespace(sys, sym)] = sym
461+
if iscall(sym) && operation(sym) === getindex
462+
ns_map[renamespace(sys, arguments(sym)[1])] = arguments(sym)[1]
502463
end
503-
idx = get(observed_idx, s, nothing)
504-
if idx !== nothing
505-
idx > maxidx && (maxidx = idx)
506-
else
507-
s′ = get(namespaced_to_obs, s, nothing)
508-
if s′ !== nothing
509-
subs[s] = s′
510-
s = s′
511-
idx = get(observed_idx, s, nothing)
512-
end
513-
if idx !== nothing
514-
idx > maxidx && (maxidx = idx)
515-
elseif !(s in sts)
516-
s′ = get(namespaced_to_sts, s, nothing)
517-
if s′ !== nothing
518-
subs[s] = s′
519-
continue
520-
end
521-
if throw
522-
Base.throw(ArgumentError("$s is neither an observed nor an unknown variable."))
523-
else
524-
# TODO: return variables that don't exist in the system.
525-
return nothing
526-
end
527-
end
528-
continue
464+
end
465+
for sym in full_parameters(sys)
466+
ns_map[renamespace(sys, sym)] = sym
467+
if iscall(sym) && operation(sym) === getindex
468+
ns_map[renamespace(sys, arguments(sym)[1])] = arguments(sym)[1]
529469
end
530470
end
531-
ts = map(t -> substitute(t, subs), ts)
532-
obsexprs = []
533-
534-
for i in 1:maxidx
535-
eq = obs[i]
536-
if is_dde(sys)
537-
eq = delay_to_function(
538-
sys, eq; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG)
471+
allsyms = Set(all_symbols(sys))
472+
for var in vs
473+
var = unwrap(var)
474+
newvar = get(ns_map, var, nothing)
475+
if newvar !== nothing
476+
namespace_subs[var] = newvar
539477
end
540-
lhs = eq.lhs
541-
rhs = eq.rhs
542-
push!(obsexprs, lhs rhs)
543478
end
479+
ts = fast_substitute(ts, namespace_subs)
544480

545-
if inputs !== nothing
546-
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
547-
end
548-
_ps = ps
549-
if ps isa Tuple
550-
ps = DestructuredArgs.(unwrap.(ps), inbounds = !checkbounds)
551-
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
552-
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), unwrap.(ps)))
553-
if isempty(ps) && inputs !== nothing
554-
ps = (:EMPTY,)
481+
obsfilter = if param_only
482+
if is_split(sys)
483+
let ic = get_index_cache(sys)
484+
eq -> !(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
485+
end
486+
else
487+
Returns(false)
555488
end
556489
else
557-
ps = (DestructuredArgs(unwrap.(ps), inbounds = !checkbounds),)
490+
Returns(true)
558491
end
559-
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
560-
if is_dde(sys)
561-
dvs = (dvs, DDE_HISTORY_FUN)
492+
dvs = if param_only
493+
()
562494
else
563-
dvs = (dvs,)
495+
(unknowns(sys),)
564496
end
565-
p_start = param_only ? 1 : (length(dvs) + 1)
566497
if inputs === nothing
567-
args = param_only ? [ps..., ivs...] : [dvs..., ps..., ivs...]
498+
inputs = ()
568499
else
569-
inputs = unwrap.(inputs)
570-
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
571-
args = param_only ? [ipts, ps..., ivs...] : [dvs..., ipts, ps..., ivs...]
572-
p_start += 1
573-
end
574-
pre = get_postprocess_fbody(sys)
575-
576-
array_wrapper = if param_only
577-
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing, inputs, history = is_dde(sys)) .∘
578-
wrap_parameter_dependencies(sys, isscalar)
579-
else
580-
wrap_array_vars(sys, ts; ps = _ps, inputs, history = is_dde(sys)) .∘
581-
wrap_parameter_dependencies(sys, isscalar)
582-
end
583-
mtkparams_wrapper = wrap_mtkparameters(sys, isscalar, p_start)
584-
if mtkparams_wrapper isa Tuple
585-
oop_mtkp_wrapper = mtkparams_wrapper[1]
586-
else
587-
oop_mtkp_wrapper = mtkparams_wrapper
500+
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
501+
inputs = (inputs,)
588502
end
589-
590-
# Need to keep old method of building the function since it uses `output_type`,
591-
# which can't be provided to `build_function`
592-
return_value = if isscalar
593-
ts[1]
594-
elseif is_tuple
595-
MakeTuple(Tuple(ts))
503+
ps = reorder_parameters(sys, ps)
504+
iv = if is_time_dependent(sys)
505+
(get_iv(sys),)
596506
else
597-
mkarray(ts, output_type)
598-
end
599-
oop_fn = Func(args, [],
600-
pre(Let(obsexprs,
601-
return_value,
602-
false)), [Expr(:meta, :propagate_inbounds)]) |> array_wrapper[1] |>
603-
oop_mtkp_wrapper |> toexpr
604-
605-
if !checkbounds
606-
oop_fn.args[end] = quote
607-
@inbounds begin
608-
$(oop_fn.args[end])
609-
end
610-
end
611-
end
612-
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
613-
614-
if !isscalar
615-
iip_fn = build_function(ts,
616-
args...;
617-
postprocess_fbody = pre,
618-
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
619-
wrap_assignments(isscalar, obsexprs),
620-
expression = Val{true})[2]
621-
if !checkbounds
622-
iip_fn.args[end] = quote
623-
@inbounds begin
624-
$(iip_fn.args[end])
625-
end
626-
end
627-
end
628-
iip_fn.args[end] = quote
629-
$(Expr(:meta, :propagate_inbounds))
630-
$(iip_fn.args[end])
631-
end
632-
633-
if !expression
634-
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)
635-
end
507+
()
636508
end
637-
if isscalar || !return_inplace
638-
return oop_fn
509+
args = (dvs..., inputs..., ps..., iv...)
510+
p_start = length(dvs) + length(inputs) + 1
511+
p_end = length(dvs) + length(inputs) + length(ps)
512+
fns = build_function_wrapper(
513+
sys, ts, args...; p_start, p_end, filter_observed = obsfilter,
514+
output_type, mkarray, try_namespaced = true, expression = Val{true})
515+
if fns isa Tuple
516+
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
517+
return return_inplace ? (oop, iip) : oop
639518
else
640-
return oop_fn, iip_fn
519+
return eval_or_rgf(fns; eval_expression, eval_module)
641520
end
642521
end
643522

0 commit comments

Comments
 (0)