Skip to content

Commit e477c2f

Browse files
committed
Invalidate jac and friends' caches when necessary
1 parent 3435e1e commit e477c2f

File tree

7 files changed

+31
-14
lines changed

7 files changed

+31
-14
lines changed

src/structural_transformation/pantelides.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,5 @@ function dae_index_lowering(sys::ODESystem; kwargs...)
149149
s = get_structure(sys)
150150
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
151151
sys, var_eq_matching, eq_to_diff = pantelides!(sys; kwargs...)
152-
return pantelides_reassemble(sys, eq_to_diff, var_eq_matching)
152+
return invalidate_cache!(pantelides_reassemble(sys, eq_to_diff, var_eq_matching))
153153
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function tearing(sys; simplify=false)
150150
sys = init_for_tearing(sys)
151151
var_eq_matching = tear_graph(sys)
152152

153-
tearing_reassemble(sys, var_eq_matching; simplify=simplify)
153+
invalidate_cache!(tearing_reassemble(sys, var_eq_matching; simplify=simplify))
154154
end
155155

156156
function init_for_tearing(sys)

src/systems/abstractsystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,23 @@ for prop in [
232232
end
233233
end
234234

235+
const EMPTY_TGRAD = Vector{Num}(undef, 0)
236+
const EMPTY_JAC = Matrix{Num}(undef, 0, 0)
237+
function invalidate_cache!(sys::AbstractSystem)
238+
if isdefined(sys, :tgrad)
239+
sys.tgrad[] = EMPTY_TGRAD
240+
elseif isdefined(sys, :jac)
241+
sys.jac[] = EMPTY_JAC
242+
elseif isdefined(sys, :ctrl_jac)
243+
sys.jac[] = EMPTY_JAC
244+
elseif isdefined(sys, :Wfact)
245+
sys.jac[] = EMPTY_JAC
246+
elseif isdefined(sys, :Wfact_t)
247+
sys.jac[] = EMPTY_JAC
248+
end
249+
return sys
250+
end
251+
235252
Setfield.get(obj::AbstractSystem, ::Setfield.PropertyLens{field}) where {field} = getfield(obj, field)
236253
@generated function ConstructionBase.setproperties(obj::AbstractSystem, patch::NamedTuple)
237254
if issubset(fieldnames(patch), fieldnames(obj))

src/systems/alias_elimination.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function alias_elimination(sys)
6868
@set! sys.states = newstates
6969
@set! sys.observed = [observed(sys); [lhs ~ rhs for (lhs, rhs) in pairs(subs)]]
7070
@set! sys.structure = nothing
71-
return sys
71+
return invalidate_cache!(sys)
7272
end
7373

7474
"""

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ function ODESystem(
147147
process_variables!(var_to_name, defaults, ps′)
148148
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
149149

150-
tgrad = RefValue(Vector{Num}(undef, 0))
151-
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
152-
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
153-
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
154-
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
150+
tgrad = RefValue(EMPTY_TGRAD)
151+
jac = RefValue{Any}(EMPTY_JAC)
152+
ctrl_jac = RefValue{Any}(EMPTY_JAC)
153+
Wfact = RefValue(EMPTY_JAC)
154+
Wfact_t = RefValue(EMPTY_JAC)
155155
sysnames = nameof.(systems)
156156
if length(unique(sysnames)) != length(sysnames)
157157
throw(ArgumentError("System names must be unique."))

src/systems/diffeqs/sdesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
130130
process_variables!(var_to_name, defaults, ps′)
131131
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
132132

133-
tgrad = RefValue(Vector{Num}(undef, 0))
134-
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
135-
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
136-
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
137-
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
133+
tgrad = RefValue(EMPTY_TGRAD)
134+
jac = RefValue{Any}(EMPTY_JAC)
135+
ctrl_jac = RefValue{Any}(EMPTY_JAC)
136+
Wfact = RefValue(EMPTY_JAC)
137+
Wfact_t = RefValue(EMPTY_JAC)
138138
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, checks = checks)
139139
end
140140

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function NonlinearSystem(eqs, states, ps;
9494
if length(unique(sysnames)) != length(sysnames)
9595
throw(ArgumentError("System names must be unique."))
9696
end
97-
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
97+
jac = RefValue{Any}(EMPTY_JAC)
9898
defaults = todict(defaults)
9999
defaults = Dict{Any,Any}(value(k) => value(v) for (k, v) in pairs(defaults))
100100

0 commit comments

Comments
 (0)