Skip to content

Commit f278b79

Browse files
authored
Merge pull request #1236 from SciML/myb/extend
More robust system nesting
2 parents 98ebb0c + cee7c02 commit f278b79

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

src/systems/abstractsystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ by default.
948948
function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameof(sys))
949949
T = SciMLBase.parameterless_type(basesys)
950950
ivs = independent_variables(basesys)
951-
if !(typeof(sys) <: T)
951+
if !(sys isa T)
952952
if length(ivs) == 0
953953
sys = convert_system(T, sys)
954954
elseif length(ivs) == 1
@@ -958,11 +958,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameo
958958
end
959959
end
960960

961-
eqs = union(equations(basesys), equations(sys))
962-
sts = union(states(basesys), states(sys))
963-
ps = union(parameters(basesys), parameters(sys))
964-
obs = union(observed(basesys), observed(sys))
965-
defs = merge(defaults(basesys), defaults(sys)) # prefer `sys`
961+
eqs = union(get_eqs(basesys), get_eqs(sys))
962+
sts = union(get_states(basesys), get_states(sys))
963+
ps = union(get_ps(basesys), get_ps(sys))
964+
obs = union(get_observed(basesys), get_observed(sys))
965+
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
966966
syss = union(get_systems(basesys), get_systems(sys))
967967

968968
if length(ivs) == 0
@@ -984,7 +984,7 @@ function compose(sys::AbstractSystem, systems::AbstractArray{<:AbstractSystem};
984984
nsys = length(systems)
985985
nsys >= 1 || throw(ArgumentError("There must be at least 1 subsystem. Got $nsys subsystems."))
986986
@set! sys.name = name
987-
@set! sys.systems = systems
987+
@set! sys.systems = [get_systems(sys); systems]
988988
return sys
989989
end
990990
compose(syss::AbstractSystem...; name=nameof(first(syss))) = compose(first(syss), collect(syss[2:end]); name=name)

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ end
187187

188188
# NOTE: equality does not check cached Jacobian
189189
function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
190+
sys1 === sys2 && return true
190191
iv1 = get_iv(sys1)
191192
iv2 = get_iv(sys2)
192193
isequal(iv1, iv2) &&

test/components.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,11 @@ sol = solve(prob, Rodas4())
4747

4848
prob = ODAEProblem(sys, u0, (0, 10.0))
4949
sol = solve(prob, Tsit5())
50+
51+
@variables t x1(t) x2(t) x3(t) x4(t)
52+
D = Differential(t)
53+
@named sys1_inner = ODESystem([D(x1) ~ x1], t)
54+
@named sys1_partial = compose(ODESystem([D(x2) ~ x2], t; name=:foo), sys1_inner)
55+
@named sys1 = extend(ODESystem([D(x3) ~ x3], t; name=:foo), sys1_partial)
56+
@named sys2 = compose(ODESystem([D(x4) ~ x4], t; name=:foo), sys1)
57+
@test_nowarn sys2.sys1.sys1_inner.x1 # test the correct nesting

0 commit comments

Comments
 (0)