Skip to content

Commit 9a896bc

Browse files
Merge pull request #3777 from AayushSabharwal/as/isapprox
feat: implement `isapprox` for systems
2 parents 7ad97b2 + 8fd7ac9 commit 9a896bc

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

src/systems/abstractsystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,9 @@ function refreshed_metadata(meta::Base.ImmutableDict)
820820
end
821821
newmeta = Base.ImmutableDict(newmeta, k => v)
822822
end
823+
if !haskey(newmeta, MutableCacheKey)
824+
newmeta = Base.ImmutableDict(newmeta, MutableCacheKey => MutableCacheT())
825+
end
823826
return newmeta
824827
end
825828

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ function SciMLBase.late_binding_update_u0_p(
738738
end
739739
newp = setp_oop(sys, syms)(newp, vals)
740740
else
741+
allsyms = nothing
741742
# if `p` is not provided or is symbolic
742743
p === missing || eltype(p) <: Pair || return newu0, newp
743744
(newu0 === nothing || isempty(newu0)) && return newu0, newp
@@ -755,6 +756,9 @@ function SciMLBase.late_binding_update_u0_p(
755756
if eltype(p) <: Pair
756757
syms = []
757758
vals = []
759+
if allsyms === nothing
760+
allsyms = all_symbols(sys)
761+
end
758762
for (k, v) in p
759763
v === nothing && continue
760764
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue

src/systems/system.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
411411
end
412412
metadata = meta
413413
end
414-
metadata = Base.ImmutableDict(metadata, MutableCacheKey => MutableCacheT())
414+
metadata = refreshed_metadata(metadata)
415415
System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints,
416416
costs, consolidate, dvs, ps, brownians, iv, observed, Equation[],
417417
var_to_name, name, description, defaults, guesses, systems, initialization_eqs,
@@ -1097,3 +1097,52 @@ function supports_initialization(sys::System)
10971097
return isempty(jumps(sys)) && _iszero(cost(sys)) &&
10981098
isempty(constraints(sys))
10991099
end
1100+
1101+
safe_eachrow(::Nothing) = nothing
1102+
safe_eachrow(x::AbstractArray) = eachrow(x)
1103+
1104+
safe_issetequal(::Nothing, ::Nothing) = true
1105+
safe_issetequal(::Nothing, x) = false
1106+
safe_issetequal(x, ::Nothing) = false
1107+
safe_issetequal(x, y) = issetequal(x, y)
1108+
1109+
"""
1110+
$(TYPEDSIGNATURES)
1111+
1112+
Check if two systems are about equal, to the extent that ModelingToolkit.jl supports. Note
1113+
that if this returns `true`, the systems are not guaranteed to be exactly equivalent
1114+
(unless `sysa === sysb`) but are highly likely to represent a similar mathematical problem.
1115+
If this returns `false`, the systems are very likely to be different.
1116+
"""
1117+
function Base.isapprox(sysa::System, sysb::System)
1118+
sysa === sysb && return true
1119+
return nameof(sysa) == nameof(sysb) &&
1120+
isequal(get_iv(sysa), get_iv(sysb)) &&
1121+
issetequal(get_eqs(sysa), get_eqs(sysb)) &&
1122+
safe_issetequal(
1123+
safe_eachrow(get_noise_eqs(sysa)), safe_eachrow(get_noise_eqs(sysb))) &&
1124+
issetequal(get_jumps(sysa), get_jumps(sysb)) &&
1125+
issetequal(get_constraints(sysa), get_constraints(sysb)) &&
1126+
issetequal(get_costs(sysa), get_costs(sysb)) &&
1127+
isequal(get_consolidate(sysa), get_consolidate(sysb)) &&
1128+
issetequal(get_unknowns(sysa), get_unknowns(sysb)) &&
1129+
issetequal(get_ps(sysa), get_ps(sysb)) &&
1130+
issetequal(get_brownians(sysa), get_brownians(sysb)) &&
1131+
issetequal(get_observed(sysa), get_observed(sysb)) &&
1132+
issetequal(get_parameter_dependencies(sysa), get_parameter_dependencies(sysb)) &&
1133+
isequal(get_description(sysa), get_description(sysb)) &&
1134+
isequal(get_defaults(sysa), get_defaults(sysb)) &&
1135+
isequal(get_guesses(sysa), get_guesses(sysb)) &&
1136+
issetequal(get_initialization_eqs(sysa), get_initialization_eqs(sysb)) &&
1137+
issetequal(get_continuous_events(sysa), get_continuous_events(sysb)) &&
1138+
issetequal(get_discrete_events(sysa), get_discrete_events(sysb)) &&
1139+
isequal(get_connector_type(sysa), get_connector_type(sysb)) &&
1140+
isequal(get_assertions(sysa), get_assertions(sysb)) &&
1141+
isequal(get_metadata(sysa), get_metadata(sysb)) &&
1142+
isequal(get_is_dde(sysa), get_is_dde(sysb)) &&
1143+
issetequal(get_tstops(sysa), get_tstops(sysb)) &&
1144+
safe_issetequal(get_ignored_connections(sysa), get_ignored_connections(sysb)) &&
1145+
isequal(get_is_initializesystem(sysa), get_is_initializesystem(sysb)) &&
1146+
isequal(get_is_discrete(sysa), get_is_discrete(sysb)) &&
1147+
isequal(get_isscheduled(sysa), get_isscheduled(sysb))
1148+
end

test/serialization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ str = String(take!(io))
2828

2929
sys = include_string(@__MODULE__, str)
3030
rc2 = expand_connections(rc_model)
31+
@test isapprox(sys, rc2)
3132
@test issetequal(equations(sys), equations(rc2))
3233
@test issetequal(unknowns(sys), unknowns(rc2))
3334
@test issetequal(parameters(sys), parameters(rc2))

0 commit comments

Comments
 (0)