Skip to content

Commit d3a7e71

Browse files
feat: cache intermediate results for observed_equations_used_by
1 parent b86cc52 commit d3a7e71

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

src/utils.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,14 @@ function observed_dependency_graph(eqs::Vector{Equation})
811811
return DiCMOBiGraph{false}(graph, matching)
812812
end
813813

814+
abstract type ObservedGraphCacheKey end
815+
816+
struct ObservedGraphCache
817+
graph::DiCMOBiGraph{false, Int, BipartiteGraph{Int, Nothing},
818+
Matching{Unassigned, Vector{Union{Unassigned, Int}}}}
819+
obsvar_to_idx::Dict{Any, Int}
820+
end
821+
814822
"""
815823
$(TYPEDSIGNATURES)
816824
@@ -831,8 +839,19 @@ Keyword arguments:
831839
"""
832840
function observed_equations_used_by(sys::AbstractSystem, exprs;
833841
involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = [])
834-
obsvars = getproperty.(obs, :lhs)
835-
graph = observed_dependency_graph(obs)
842+
if iscomplete(sys) && obs == observed(sys)
843+
cache = getmetadata(sys, MutableCacheKey, nothing)
844+
obs_graph_cache = get!(cache, ObservedGraphCacheKey) do
845+
obsvar_to_idx = Dict{Any, Int}([eq.lhs => i for (i, eq) in enumerate(obs)])
846+
graph = observed_dependency_graph(obs)
847+
return ObservedGraphCache(graph, obsvar_to_idx)
848+
end
849+
@unpack obsvar_to_idx, graph = obs_graph_cache
850+
else
851+
obsvar_to_idx = Dict([eq.lhs => i for (i, eq) in enumerate(obs)])
852+
graph = observed_dependency_graph(obs)
853+
end
854+
836855
if !(available_vars isa Set)
837856
available_vars = Set(available_vars)
838857
end
@@ -841,7 +860,9 @@ function observed_equations_used_by(sys::AbstractSystem, exprs;
841860
for sym in involved_vars
842861
sym in available_vars && continue
843862
arrsym = iscall(sym) && operation(sym) === getindex ? arguments(sym)[1] : nothing
844-
idx = findfirst(v -> isequal(v, sym) || isequal(v, arrsym), obsvars)
863+
idx = @something(get(obsvar_to_idx, sym, nothing),
864+
get(obsvar_to_idx, arrsym, nothing),
865+
Some(nothing))
845866
idx === nothing && continue
846867
idx in obsidxs && continue
847868
parents = dfs_parents(graph, idx)

0 commit comments

Comments
 (0)