Skip to content

Commit d7276db

Browse files
authored
Merge pull request #978 from SciML/myb/print
Round-trip system printing
2 parents d6bb4b6 + fbd580b commit d7276db

File tree

4 files changed

+107
-3
lines changed

4 files changed

+107
-3
lines changed

src/systems/abstractsystem.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,98 @@ function (f::AbstractSysToExpr)(O)
411411
return build_expr(:call, Any[operation(O); f.(arguments(O))])
412412
end
413413

414+
###
415+
### System utils
416+
###
417+
function push_vars!(stmt, name, typ, vars)
418+
isempty(vars) && return
419+
vars_expr = Expr(:macrocall, typ, nothing)
420+
for s in vars
421+
if istree(s)
422+
f = nameof(operation(s))
423+
args = arguments(s)
424+
ex = :($f($(args...)))
425+
else
426+
ex = nameof(s)
427+
end
428+
push!(vars_expr.args, ex)
429+
end
430+
push!(stmt, :($name = $collect($vars_expr)))
431+
return
432+
end
433+
434+
function round_trip_expr(t, var2name)
435+
name = get(var2name, t, nothing)
436+
name !== nothing && return name
437+
t isa Sym && return nameof(t)
438+
istree(t) || return t
439+
f = round_trip_expr(operation(t), var2name)
440+
args = map(Base.Fix2(round_trip_expr, var2name), arguments(t))
441+
return :($f($(args...)))
442+
end
443+
round_trip_eq(eq, var2name) = Expr(:call, :~, round_trip_expr(eq.lhs, var2name), round_trip_expr(eq.rhs, var2name))
444+
445+
function push_eqs!(stmt, eqs, var2name)
446+
eqs_name = gensym(:eqs)
447+
eqs_expr = Expr(:vcat)
448+
eqs_blk = Expr(:(=), eqs_name, eqs_expr)
449+
for eq in eqs
450+
push!(eqs_expr.args, round_trip_eq(eq, var2name))
451+
end
452+
453+
push!(stmt, eqs_blk)
454+
return eqs_name
455+
end
456+
457+
function push_defaults!(stmt, defs, var2name)
458+
defs_name = gensym(:defs)
459+
defs_expr = Expr(:call, Dict)
460+
defs_blk = Expr(:(=), defs_name, defs_expr)
461+
for d in defs
462+
n = round_trip_expr(d.first, var2name)
463+
v = round_trip_expr(d.second, var2name)
464+
push!(defs_expr.args, :($(=>)($n, $v)))
465+
end
466+
467+
push!(stmt, defs_blk)
468+
return defs_name
469+
end
470+
471+
function toexpr(sys::AbstractSystem)
472+
sys = flatten(sys)
473+
expr = Expr(:block)
474+
stmt = expr.args
475+
476+
iv = independent_variable(sys)
477+
ivname = gensym(:iv)
478+
if iv !== nothing
479+
push!(stmt, :($ivname = (@variables $(getname(iv)))[1]))
480+
end
481+
482+
stsname = gensym(:sts)
483+
sts = states(sys)
484+
push_vars!(stmt, stsname, Symbol("@variables"), sts)
485+
psname = gensym(:ps)
486+
ps = parameters(sys)
487+
push_vars!(stmt, psname, Symbol("@parameters"), ps)
488+
489+
var2name = Dict{Any,Symbol}()
490+
for v in Iterators.flatten((sts, ps))
491+
var2name[v] = getname(v)
492+
end
493+
494+
eqs_name = push_eqs!(stmt, equations(sys), var2name)
495+
defs_name = push_defaults!(stmt, defaults(sys), var2name)
496+
497+
if sys isa ODESystem
498+
push!(stmt, :($ODESystem($eqs_name, $ivname, $stsname, $psname; defaults=$defs_name)))
499+
elseif sys isa NonlinearSystem
500+
push!(stmt, :($NonlinearSystem($eqs_name, $stsname, $psname; defaults=$defs_name)))
501+
end
502+
503+
striplines(expr) # keeping the line numbers is never helpful
504+
end
505+
414506
function Base.show(io::IO, ::MIME"text/plain", sys::AbstractSystem)
415507
eqs = equations(sys)
416508
if eqs isa AbstractArray
@@ -581,6 +673,10 @@ function check_eqs_u0(eqs, dvs, u0)
581673
return nothing
582674
end
583675

676+
###
677+
### Connectors
678+
###
679+
584680
function with_connection_type(expr)
585681
@assert expr isa Expr && (expr.head == :function || (expr.head == :(=) &&
586682
expr.args[1] isa Expr &&

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,10 @@ function flatten(sys::NonlinearSystem)
308308
)
309309
end
310310
end
311+
312+
function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem)
313+
_eq_unordered(get_eqs(sys1), get_eqs(sys2)) &&
314+
_eq_unordered(get_states(sys1), get_states(sys2)) &&
315+
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
316+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
317+
end

test/nonlinearsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ end
2121
eqs = [0 ~ σ*(y-x),
2222
0 ~ x*-z)-y,
2323
0 ~ x*y - β*z]
24-
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
24+
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β], defaults = Dict(x => 2))
25+
@test eval(toexpr(ns)) == ns
2526
test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β))
2627
@test begin
2728
f = eval(generate_function(ns, [x,y,z], [σ,ρ,β])[2])

test/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ eqs = [D(x) ~ σ*(y-x),
1717
D(z) ~ x*y - β*z]
1818

1919
ModelingToolkit.toexpr.(eqs)[1]
20-
:(derivative(x(t), t) = σ * (y(t) - x(t))).args
21-
de = ODESystem(eqs)
20+
de = ODESystem(eqs; defaults=Dict(x => 1))
21+
@test eval(toexpr(de)) == de
2222

2323
generate_function(de)
2424

0 commit comments

Comments
 (0)