Skip to content

Commit 9f0433b

Browse files
Merge pull request #2384 from chriselrod/cachehash
Cache hash of ConnectionElement
2 parents 26e5248 + a01f0a6 commit 9f0433b

File tree

9 files changed

+71
-27
lines changed

9 files changed

+71
-27
lines changed

src/systems/abstractsystem.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,14 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187187
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
188188
return unwrap(sym) in 1:length(unknown_states(sys))
189189
end
190-
return any(isequal(sym), unknown_states(sys)) || hasname(sym) && is_variable(sys, getname(sym))
190+
return any(isequal(sym), unknown_states(sys)) ||
191+
hasname(sym) && is_variable(sys, getname(sym))
191192
end
192193

193194
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
194-
return any(isequal(sym), getname.(unknown_states(sys))) || count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
195+
return any(isequal(sym), getname.(unknown_states(sys))) ||
196+
count('', string(sym)) == 1 &&
197+
count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
195198
end
196199

197200
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
@@ -224,12 +227,14 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
224227
return unwrap(sym) in 1:length(parameters(sys))
225228
end
226229

227-
return any(isequal(sym), parameters(sys)) || hasname(sym) && is_parameter(sys, getname(sym))
230+
return any(isequal(sym), parameters(sys)) ||
231+
hasname(sym) && is_parameter(sys, getname(sym))
228232
end
229233

230234
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
231235
return any(isequal(sym), getname.(parameters(sys))) ||
232-
count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
236+
count('', string(sym)) == 1 &&
237+
count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
233238
end
234239

235240
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
@@ -270,7 +275,8 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSys
270275
end
271276

272277
function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
273-
return !is_variable(sys, sym) && !is_parameter(sys, sym) && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
278+
return !is_variable(sys, sym) && !is_parameter(sys, sym) &&
279+
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
274280
end
275281

276282
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
@@ -621,13 +627,25 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys
621627
O
622628
end
623629
end
624-
630+
_nonum(@nospecialize x) = x isa Num ? x.val : x
625631
function states(sys::AbstractSystem)
626632
sts = get_states(sys)
627633
systems = get_systems(sys)
628-
unique(isempty(systems) ?
629-
sts :
630-
[sts; reduce(vcat, namespace_variables.(systems))])
634+
nonunique_states = if isempty(systems)
635+
sts
636+
else
637+
system_states = reduce(vcat, namespace_variables.(systems))
638+
isempty(sts) ? system_states : [sts; system_states]
639+
end
640+
isempty(nonunique_states) && return nonunique_states
641+
# `Vector{Any}` is incompatible with the `SymbolicIndexingInterface`, which uses
642+
# `elsymtype = symbolic_type(eltype(_arg))`
643+
# which inappropriately returns `NotSymbolic()`
644+
if nonunique_states isa Vector{Any}
645+
nonunique_states = _nonum.(nonunique_states)
646+
end
647+
@assert typeof(nonunique_states) !== Vector{Any}
648+
unique(nonunique_states)
631649
end
632650

633651
function parameters(sys::AbstractSystem)

src/systems/connectors.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,39 @@ struct ConnectionElement
147147
sys::LazyNamespace
148148
v::Any
149149
isouter::Bool
150+
h::UInt
150151
end
151-
Base.nameof(l::ConnectionElement) = renamespace(nameof(l.sys), getname(l.v))
152-
function Base.hash(l::ConnectionElement, salt::UInt)
153-
hash(nameof(l.sys)) hash(l.v) hash(l.isouter) salt
152+
function _hash_impl(sys, v, isouter)
153+
hashcore = hash(nameof(sys)) hash(getname(v))
154+
hashouter = isouter ? hash(true) : hash(false)
155+
hashcore hashouter
156+
end
157+
function ConnectionElement(sys::LazyNamespace, v, isouter::Bool)
158+
ConnectionElement(sys, v, isouter, _hash_impl(sys, v, isouter))
154159
end
160+
Base.nameof(l::ConnectionElement) = renamespace(nameof(l.sys), getname(l.v))
155161
Base.isequal(l1::ConnectionElement, l2::ConnectionElement) = l1 == l2
156162
function Base.:(==)(l1::ConnectionElement, l2::ConnectionElement)
157163
nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v) && l1.isouter == l2.isouter
158164
end
165+
function Base.hash(e::ConnectionElement, salt::UInt)
166+
@inbounds begin
167+
@boundscheck begin
168+
@assert e.h === _hash_impl(e.sys, e.v, e.isouter)
169+
end
170+
end
171+
e.h salt
172+
end
159173
namespaced_var(l::ConnectionElement) = states(l, l.v)
160174
states(l::ConnectionElement, v) = states(copy(l.sys), v)
161175

176+
function withtrueouter(e::ConnectionElement)
177+
e.isouter && return e
178+
# we undo the xor
179+
newhash = (e.h hash(false)) hash(true)
180+
ConnectionElement(e.sys, e.v, true, newhash)
181+
end
182+
162183
struct ConnectionSet
163184
set::Vector{ConnectionElement} # namespace.sys, var, isouter
164185
end
@@ -353,9 +374,7 @@ function partial_merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
353374
merged = false
354375
for (j, cset) in enumerate(csets)
355376
if allouter
356-
cset = ConnectionSet(map(cset.set) do e
357-
@set! e.isouter = true
358-
end)
377+
cset = ConnectionSet(map(withtrueouter, cset.set))
359378
end
360379
idx = nothing
361380
for e in cset.set
@@ -390,7 +409,7 @@ end
390409

391410
function generate_connection_equations_and_stream_connections(csets::AbstractVector{
392411
<:ConnectionSet,
393-
})
412+
})
394413
eqs = Equation[]
395414
stream_connections = ConnectionSet[]
396415

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ function (ratemap::JumpSysMajParamMapper{
491491
U,
492492
V,
493493
W,
494-
})(params) where {U <: AbstractArray,
494+
})(params) where {U <: AbstractArray,
495495
V <: AbstractArray, W}
496496
updateparams!(ratemap, params)
497497
[convert(W, value(substitute(paramexpr, ratemap.subdict)))

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ function initializesystem(sys::ODESystem; name = nameof(sys), kwargs...)
5252
kwargs...)
5353

5454
return sys_nl
55-
end
55+
end

src/systems/optimization/modelingtoolkitize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; kwargs...)
2727
for i in 1:num_cons
2828
if !isinf(prob.lcons[i])
2929
if prob.lcons[i] != prob.ucons[i]
30-
push!(cons, prob.lcons[i] lhs[i])
30+
push!(cons, prob.lcons[i] lhs[i])
3131
else
3232
push!(cons, lhs[i] ~ prob.ucons[i])
3333
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
269269
expression = Val{false})
270270

271271
obj_expr = subs_constants(objective(sys))
272-
272+
273273
if grad
274274
grad_oop, grad_iip = generate_gradient(sys, checkbounds = checkbounds,
275275
linenumbers = linenumbers,

src/variables.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,4 +463,4 @@ function get_default_or_guess(x)
463463
else
464464
return getguess(x)
465465
end
466-
end
466+
end

test/nonlinearsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ end
237237
@testset "Initialization System" begin
238238
# Define the Lotka Volterra system which begins at steady state
239239
@parameters t
240-
pars = @parameters a=1.5 b=1.0 c=3.0 d=1.0 dx_ss = 1e-5
240+
pars = @parameters a=1.5 b=1.0 c=3.0 d=1.0 dx_ss=1e-5
241241

242242
vars = @variables begin
243243
dx(t),
@@ -274,4 +274,4 @@ end
274274

275275
# Confirm for all the states of the simplified system
276276
@test all(.≈(sol[states(sys_simple)], [1e-5 / 1.5, 0]; atol = 1e-8))
277-
end
277+
end

test/optimizationsystem.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,17 @@ end
289289
@parameters a b
290290
loss = (a - x)^2 + b * (y - x^2)^2
291291
@named sys = OptimizationSystem(loss, [x, y], [a, b], constraints = [x^2 + y^2 0.0])
292-
@test_throws ArgumentError OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0], lcons = [0.0])
293-
@test_throws ArgumentError OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0], ucons = [0.0])
292+
@test_throws ArgumentError OptimizationProblem(sys,
293+
[x => 0.0, y => 0.0],
294+
[a => 1.0, b => 100.0],
295+
lcons = [0.0])
296+
@test_throws ArgumentError OptimizationProblem(sys,
297+
[x => 0.0, y => 0.0],
298+
[a => 1.0, b => 100.0],
299+
ucons = [0.0])
294300

295301
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0])
296302
@test prob.f.expr isa Symbolics.Symbolic
297-
@test all(prob.f.cons_expr[i].lhs isa Symbolics.Symbolic for i in 1:length(prob.f.cons_expr))
298-
end
303+
@test all(prob.f.cons_expr[i].lhs isa Symbolics.Symbolic
304+
for i in 1:length(prob.f.cons_expr))
305+
end

0 commit comments

Comments
 (0)