Skip to content

Commit f6802a5

Browse files
authored
Merge pull request #765 from SciML/myb/pretty_print
Pretty print and some fixes
2 parents 755ebdf + 3329dfe commit f6802a5

12 files changed

+85
-44
lines changed

src/ModelingToolkit.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,13 @@ end
152152
tosymbolic(a::Num) = tosymbolic(value(a))
153153
tosymbolic(a::Sym) = tovar(a)
154154
tosymbolic(a) = a
155-
@num_method Base.isless isless(tosymbolic(a), tosymbolic(b)) (Real,)
156-
@num_method Base.:(<) (tosymbolic(a) < tosymbolic(b)) (Real,)
157-
@num_method Base.:(<=) (tosymbolic(a) <= tosymbolic(b)) (Real,)
158-
@num_method Base.:(>) (tosymbolic(a) > tosymbolic(b)) (Real,)
159-
@num_method Base.:(>=) (tosymbolic(a) >= tosymbolic(b)) (Real,)
155+
@num_method Base.isless (val = isless(tosymbolic(a), tosymbolic(b)); val isa Bool ? val : Num(val)) (Real,)
156+
@num_method Base.:(<) (val = tosymbolic(a) < tosymbolic(b) ; val isa Bool ? val : Num(val)) (Real,)
157+
@num_method Base.:(<=) (val = tosymbolic(a) <= tosymbolic(b) ; val isa Bool ? val : Num(val)) (Real,)
158+
@num_method Base.:(>) (val = tosymbolic(a) > tosymbolic(b) ; val isa Bool ? val : Num(val)) (Real,)
159+
@num_method Base.:(>=) (val = tosymbolic(a) >= tosymbolic(b) ; val isa Bool ? val : Num(val)) (Real,)
160+
@num_method Base.:(==) (val = tosymbolic(a) == tosymbolic(b) ; val isa Bool ? val : Num(val)) (AbstractFloat,Number)
160161
@num_method Base.isequal isequal(tosymbolic(a), tosymbolic(b)) (AbstractFloat, Number, Symbolic)
161-
@num_method Base.:(==) tosymbolic(a) == tosymbolic(b) (AbstractFloat,Number)
162162

163163
Base.hash(x::Num, h::UInt) = hash(value(x), h)
164164

src/equations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct Equation
1414
end
1515
Base.:(==)(a::Equation, b::Equation) = all(isequal.((a.lhs, a.rhs), (b.lhs, b.rhs)))
1616
Base.hash(a::Equation, salt::UInt) = hash(a.lhs, hash(a.rhs, salt))
17+
1718
Base.show(io::IO, eq::Equation) = print(io, eq.lhs, " ~ ", eq.rhs)
1819

1920
SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...)

src/extra_functions.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
@register Base.getindex(x,i::Integer)
2-
@register Base.getindex(x,i)
1+
@register Base.getindex(x,i::Integer) false
2+
@register Base.getindex(x,i) false
33
@register Base.binomial(n,k)
44

55
@register Base.signbit(x)
@@ -22,7 +22,7 @@ function ModelingToolkit.derivative(::typeof(max), args::NTuple{2,Any}, ::Val{2}
2222
IfElse.ifelse(x > y, zero(y), one(y))
2323
end
2424

25-
@register IfElse.ifelse(x,y,z::Any)
25+
IfElse.ifelse(x::Num,y,z) = Num(Term{Real}(IfElse.ifelse, [value(x), value(y), value(z)]))
2626
ModelingToolkit.derivative(::typeof(IfElse.ifelse), args::NTuple{3,Any}, ::Val{1}) = 0
2727
ModelingToolkit.derivative(::typeof(IfElse.ifelse), args::NTuple{3,Any}, ::Val{2}) = IfElse.ifelse(args[1],1,0)
2828
ModelingToolkit.derivative(::typeof(IfElse.ifelse), args::NTuple{3,Any}, ::Val{3}) = IfElse.ifelse(args[1],0,1)
@@ -36,8 +36,8 @@ ModelingToolkit.@register Distributions.cdf(dist,x)
3636
ModelingToolkit.@register Distributions.logcdf(dist,x)
3737
ModelingToolkit.@register Distributions.quantile(dist,x)
3838

39-
ModelingToolkit.@register Distributions.Uniform(mu,sigma)
40-
ModelingToolkit.@register Distributions.Normal(mu,sigma)
39+
ModelingToolkit.@register Distributions.Uniform(mu,sigma) false
40+
ModelingToolkit.@register Distributions.Normal(mu,sigma) false
4141

4242
@register (x::Num, y::AbstractArray)
4343
@register (x, y)
@@ -74,4 +74,4 @@ function LinearAlgebra.det(A::AbstractMatrix{<:Num}; laplace=true)
7474
end
7575
return det(lu(A; check = false))
7676
end
77-
end
77+
end

src/register_function.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
@register(expr, Ts = [Num, Symbolic, Real])
2+
@register(expr, define_promotion, Ts = [Num, Symbolic, Real])
33
44
Overload approperate methods such that ModelingToolkit can stop tracing into the
55
registered function.
@@ -11,7 +11,7 @@ registered function.
1111
@register hoo(x, y)::Int # `hoo` returns `Int`
1212
```
1313
"""
14-
macro register(expr, Ts = [Num, Symbolic, Real])
14+
macro register(expr, define_promotion = true, Ts = [Num, Symbolic, Real])
1515
if expr.head === :(::)
1616
ret_type = expr.args[2]
1717
expr = expr.args[1]
@@ -48,7 +48,7 @@ macro register(expr, Ts = [Num, Symbolic, Real])
4848
push!(
4949
ex.args,
5050
quote
51-
if $!($hasmethod($promote_symtype, $Tuple{$typeof($f), $Vararg}))
51+
if $define_promotion
5252
(::$typeof($promote_symtype))(::$typeof($f), args...) = $ret_type
5353
end
5454
end

src/solve.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,23 @@ Assumes `length(eqs) == length(vars)`
8282
8383
Currently only works if all equations are linear.
8484
"""
85-
function solve_for(eqs, vars)
85+
function solve_for(eqs, vars; simplify=true)
8686
A, b = A_b(eqs, vars)
87-
_solve(A, b)
87+
#TODO: we need to make sure that `solve_for(eqs, vars)` contains no `vars`
88+
_solve(A, b, simplify)
8889
end
8990

90-
function _solve(A::AbstractMatrix, b::AbstractArray)
91+
function _solve(A::AbstractMatrix, b::AbstractArray, do_simplify)
9192
A = SymbolicUtils.simplify.(Num.(A), polynorm=true)
9293
b = SymbolicUtils.simplify.(Num.(b), polynorm=true)
93-
value.(SymbolicUtils.simplify.(sym_lu(A) \ b))
94+
sol = value.(sym_lu(A) \ b)
95+
do_simplify ? SymbolicUtils.simplify.(sol, polynorm=true) : sol
96+
end
97+
98+
function _solve(a, b, do_simplify)
99+
sol = value(b/a)
100+
do_simplify ? SymbolicUtils.simplify(sol, polynorm=true) : sol
94101
end
95-
_solve(a, b) = value(SymbolicUtils.simplify(b/a, polynorm=true))
96102

97103
# ldiv below
98104

src/systems/abstractsystem.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,53 @@ function (f::AbstractSysToExpr)(O)
316316
end
317317
return build_expr(:call, Any[operation(O); f.(arguments(O))])
318318
end
319+
320+
function Base.show(io::IO, sys::AbstractSystem)
321+
eqs = equations(sys)
322+
Base.printstyled(io, "Equations ($(length(eqs))):\n"; bold=true)
323+
Base.print_matrix(io, eqs)
324+
println(io)
325+
326+
rows = first(displaysize(io)) ÷ 3
327+
limit = get(io, :limit, false)
328+
329+
vars = states(sys); nvars = length(vars)
330+
Base.printstyled(io, "States ($nvars):"; bold=true)
331+
nrows = min(nvars, limit ? rows : nvars)
332+
limited = nrows < length(vars)
333+
d_u0 = default_u0(sys)
334+
for i in 1:nrows
335+
s = vars[i]
336+
print(io, "\n ", s)
337+
338+
val = get(d_u0, s, nothing)
339+
if val !== nothing
340+
print(io, " [defaults to $val]")
341+
end
342+
end
343+
limited && print(io, "\n")
344+
println(io)
345+
346+
vars = parameters(sys); nvars = length(vars)
347+
Base.printstyled(io, "Parameters ($nvars):"; bold=true)
348+
nrows = min(nvars, limit ? rows : nvars)
349+
limited = nrows < length(vars)
350+
d_p = default_p(sys)
351+
for i in 1:nrows
352+
s = vars[i]
353+
print(io, "\n ", s)
354+
355+
val = get(d_p, s, nothing)
356+
if val !== nothing
357+
print(io, " [defaults to $val]")
358+
end
359+
end
360+
limited && print(io, "\n")
361+
362+
s = get_structure(sys)
363+
if s !== nothing
364+
Base.printstyled(io, "\nIncidence matrix:"; color=:magenta)
365+
show(io, incidence_matrix(s.graph, Num(Sym{Real}(:×))))
366+
end
367+
return nothing
368+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ function ODESystem(eqs, iv=nothing; kwargs...)
138138
end
139139
end
140140
end
141+
iv = value(iv)
141142
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
142143
for eq in eqs
143144
collect_vars!(allstates, ps, eq.lhs, iv)

src/systems/systemstructure.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ function init_graph(sys)
9595
algvar_offset = 2dxvar_offset
9696

9797
fullvars = [xvars; dxvars; algvars]
98-
sys = reordersys(sys, dxvar_offset, fullvars)
9998
eqs = equations(sys)
10099
idxmap = Dict(fullvars .=> 1:length(fullvars))
101100
graph = BipartiteGraph(length(eqs), length(fullvars))
@@ -122,21 +121,3 @@ function init_graph(sys)
122121
varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list
123122
sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph
124123
end
125-
126-
function reordersys(sys, dxvar_offset, fullvars)
127-
eqs = equations(sys)
128-
neweqs = similar(eqs, Equation)
129-
eqidxmap = Dict(@view(fullvars[dxvar_offset+1:2dxvar_offset]) .=> (1:dxvar_offset))
130-
varidxmap = Dict([@view(fullvars[1:dxvar_offset]); @view(fullvars[2dxvar_offset+1:end])] .=> (1:length(fullvars)-dxvar_offset))
131-
algidx = dxvar_offset
132-
for eq in eqs
133-
if isdiffeq(eq)
134-
neweqs[eqidxmap[eq.lhs]] = eq
135-
else
136-
neweqs[algidx+=1] = eq
137-
end
138-
end
139-
sts = states(sys)
140-
@set! sys.eqs = neweqs
141-
@set! sys.states = sts[map(s->varidxmap[s], sts)]
142-
end

test/operation_overloads.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ F = lu(X)
2525
R = simplify.(F.L * F.U - X[F.p, :], polynorm=true)
2626
@test iszero(R)
2727
@test simplify.(F \ X) == I
28-
@test ModelingToolkit._solve(X, X) == I
28+
@test ModelingToolkit._solve(X, X, true) == I
2929
inv(X)
3030
qr(X)
3131

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,13 @@ reduced_system = alias_elimination(connected; conservative=false)
9797
]) |> isempty
9898

9999
reduced_eqs = [
100+
0 ~ a + lorenz1.x - (lorenz2.y)
100101
D(lorenz1.x) ~ lorenz2.x + lorenz2.y + lorenz1.σ*((lorenz1.y) - (lorenz1.x)) - (lorenz2.z)
101102
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ - (lorenz1.z)) - ((lorenz1.x) + (lorenz1.y) - (lorenz1.z))
102103
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - (lorenz1.β*(lorenz1.z))
103104
D(lorenz2.x) ~ lorenz1.x + lorenz1.y + lorenz2.σ*((lorenz2.y) - (lorenz2.x)) - (lorenz1.z)
104105
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ - (lorenz2.z)) - ((lorenz2.x) + (lorenz2.y) - (lorenz2.z))
105106
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - (lorenz2.β*(lorenz2.z))
106-
0 ~ a + lorenz1.x - (lorenz2.y)
107107
]
108108

109109
test_equal.(equations(reduced_system), reduced_eqs)

0 commit comments

Comments
 (0)