Skip to content

Commit b2b337e

Browse files
committed
Add setproperties overload
Fix #773
1 parent ae029d0 commit b2b337e

File tree

5 files changed

+34
-6
lines changed

5 files changed

+34
-6
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "5.6.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
89
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
910
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1011
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
@@ -37,6 +38,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3738

3839
[compat]
3940
ArrayInterface = "2.8, 3.0"
41+
ConstructionBase = "1"
4042
DataStructures = "0.17, 0.18"
4143
DiffEqBase = "6.54.0"
4244
DiffEqJump = "6.7.5"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
66
using Latexify, Unitful, ArrayInterface
77
using MacroTools
88
using UnPack: @unpack
9-
using Setfield
9+
using Setfield, ConstructionBase
1010
using DiffEqJump
1111
using DataStructures
1212
using SpecialFunctions, NaNMath

src/systems/abstractsystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,25 @@ for prop in [:eqs, :iv, :states, :ps, :default_p, :default_u0, :observed, :tgrad
143143
end
144144
end
145145

146+
Setfield.get(obj, l::Setfield.PropertyLens{field}) where {field} = getfield(obj, field)
147+
@generated function ConstructionBase.setproperties(obj::AbstractSystem, patch::NamedTuple)
148+
if issubset(fieldnames(patch), fieldnames(obj))
149+
args = map(fieldnames(obj)) do fn
150+
if fn in fieldnames(patch)
151+
:(patch.$fn)
152+
else
153+
:(getfield(obj, $(Meta.quot(fn))))
154+
end
155+
end
156+
return Expr(:block,
157+
Expr(:meta, :inline),
158+
Expr(:call,:(constructorof($obj)), args...)
159+
)
160+
else
161+
error("This should never happen. Trying to set $(typeof(obj)) with $patch.")
162+
end
163+
end
164+
146165
function Base.getproperty(sys::AbstractSystem, name::Symbol)
147166
sysname = nameof(sys)
148167
systems = get_systems(sys)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ Takes a Nth order ODESystem and returns a new ODESystem written in first order
55
form by defining new variables which represent the N-1 derivatives.
66
"""
77
function ode_order_lowering(sys::ODESystem)
8-
eqs_lowered, new_vars = ode_order_lowering(equations(sys), sys.iv, states(sys))
9-
return ODESystem(eqs_lowered, sys.iv, new_vars, sys.ps)
8+
iv = independent_variable(sys)
9+
eqs_lowered, new_vars = ode_order_lowering(equations(sys), iv, states(sys))
10+
return ODESystem(eqs_lowered, iv, new_vars, parameters(sys))
1011
end
1112

1213
function ode_order_lowering(eqs, iv, states)

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,15 @@ function collect_var!(states, parameters, var, iv)
182182
end
183183

184184
# NOTE: equality does not check cached Jacobian
185-
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
186-
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
187-
_eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps)
185+
function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
186+
iv1 = independent_variable(sys1)
187+
iv2 = independent_variable(sys2)
188+
isequal(iv1, iv2) &&
189+
_eq_unordered(get_eqs(sys1), get_eqs(sys2)) &&
190+
_eq_unordered(get_states(sys1), get_states(sys2)) &&
191+
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
192+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
193+
end
188194

189195
function flatten(sys::ODESystem)
190196
systems = get_systems(sys)

0 commit comments

Comments
 (0)