Skip to content

Commit 70d876b

Browse files
authored
Merge pull request #518 from SciML/sm/symutils3
WIP: Unify the Term definition between MTK and SymbolicUtils
2 parents 35af497 + d5f81cd commit 70d876b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+989
-1050
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3434
[compat]
3535
ArrayInterface = "2.8"
3636
DataStructures = "0.17, 0.18"
37-
DiffEqBase = "6.48"
37+
DiffEqBase = "6.48.1"
3838
DiffEqJump = "6.7.5"
3939
DiffRules = "0.1, 1.0"
4040
DocStringExtensions = "0.7, 0.8"
@@ -50,7 +50,7 @@ RuntimeGeneratedFunctions = "0.4"
5050
SafeTestsets = "0.0.1"
5151
SpecialFunctions = "0.7, 0.8, 0.9, 0.10"
5252
StaticArrays = "0.10, 0.11, 0.12"
53-
SymbolicUtils = "0.5"
53+
SymbolicUtils = "0.6"
5454
TreeViews = "0.3"
5555
UnPack = "0.1, 1.0"
5656
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 94 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
2121
using RecursiveArrayTools
2222

2323
import SymbolicUtils
24-
import SymbolicUtils: to_symbolic, FnType, @rule, Rewriters, Term
24+
import SymbolicUtils: Term, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
2525

2626
using LinearAlgebra: LU, BlasInt
2727

@@ -31,12 +31,99 @@ import TreeViews
3131

3232
using Requires
3333

34+
export Num, Variable
3435
"""
3536
$(TYPEDEF)
3637
37-
Base type for a symbolic expression.
38+
Wrap anything in a type that is a subtype of Real
3839
"""
39-
abstract type Expression <: Number end
40+
struct Num <: Real
41+
val
42+
end
43+
44+
const show_numwrap = Ref(false)
45+
46+
Num(x::Num) = x # ideally this should never be called
47+
(n::Num)(args...) = Num(value(n)(map(value,args)...))
48+
value(x) = x
49+
value(x::Num) = x.val
50+
51+
52+
using SymbolicUtils: to_symbolic
53+
SymbolicUtils.to_symbolic(n::Num) = value(n)
54+
SymbolicUtils.@number_methods(Num,
55+
Num(f(value(a))),
56+
Num(f(value(a), value(b))))
57+
58+
SymbolicUtils.simplify(n::Num; kw...) = Num(SymbolicUtils.simplify(value(n); kw...))
59+
60+
SymbolicUtils.symtype(n::Num) = symtype(n.val)
61+
62+
function Base.iszero(x::Num)
63+
_x = SymbolicUtils.to_mpoly(value(x))[1]
64+
return (_x isa Number || _x isa SymbolicUtils.MPoly) && iszero(_x)
65+
end
66+
67+
import SymbolicUtils: <ₑ, Symbolic, Term, operation, arguments
68+
69+
Base.show(io::IO, n::Num) = show_numwrap[] ? print(io, :(Num($(value(n))))) : Base.show(io, value(n))
70+
71+
Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num
72+
Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num
73+
Base.getproperty(t::Term, f::Symbol) = f === :op ? operation(t) : f === :args ? arguments(t) : getfield(t, f)
74+
<(s::Num, x) = value(s) <value(x)
75+
<(s, x::Num) = value(s) <value(x)
76+
<(s::Num, x::Num) = value(s) <value(x)
77+
78+
for T in (Integer, Rational)
79+
@eval Base.:(^)(n::Num, i::$T) = Num(Term{symtype(n)}(^, [value(n),i]))
80+
end
81+
82+
macro num_method(f, expr, Ts=nothing)
83+
if Ts === nothing
84+
Ts = [Any]
85+
else
86+
@assert Ts.head == :tuple
87+
# e.g. a tuple or vector
88+
Ts = Ts.args
89+
end
90+
91+
ms = [quote
92+
$f(a::$T, b::$Num) = $expr
93+
$f(a::$Num, b::$T) = $expr
94+
end for T in Ts]
95+
quote
96+
$f(a::$Num, b::$Num) = $expr
97+
$(ms...)
98+
end |> esc
99+
end
100+
101+
"""
102+
tosymbolic(a::Union{Sym,Num}) -> Sym{Real}
103+
tosymbolic(a::T) -> T
104+
"""
105+
tosymbolic(a::Num) = tosymbolic(value(a))
106+
tosymbolic(a::Sym) = tovar(a)
107+
tosymbolic(a) = a
108+
@num_method Base.isless isless(tosymbolic(a), tosymbolic(b)) (Real,)
109+
@num_method Base.:(<) (tosymbolic(a) < tosymbolic(b)) (Real,)
110+
@num_method Base.:(<=) (tosymbolic(a) <= tosymbolic(b)) (Real,)
111+
@num_method Base.:(>) (tosymbolic(a) > tosymbolic(b)) (Real,)
112+
@num_method Base.:(>=) (tosymbolic(a) >= tosymbolic(b)) (Real,)
113+
@num_method Base.isequal isequal(tosymbolic(a), tosymbolic(b)) (Number, Symbolic)
114+
@num_method Base.:(==) tosymbolic(a) == tosymbolic(b) (Number,)
115+
116+
Base.hash(x::Num, h::UInt) = hash(value(x), h)
117+
118+
Base.convert(::Type{Num}, x::Symbolic{<:Number}) = Num(x)
119+
Base.convert(::Type{Num}, x::Number) = Num(x)
120+
Base.convert(::Type{Num}, x::Num) = x
121+
122+
Base.convert(::Type{<:Array{Num}}, x::AbstractArray) = map(Num, x)
123+
Base.convert(::Type{<:Array{Num}}, x::AbstractArray{Num}) = x
124+
Base.convert(::Type{Sym}, x::Num) = value(x) isa Sym ? value(x) : error("cannot convert $x to Sym")
125+
126+
LinearAlgebra.lu(x::Array{Num}; kw...) = lu(x, Val{false}(); kw...)
40127

41128
"""
42129
$(TYPEDEF)
@@ -46,14 +133,6 @@ TODO
46133
abstract type AbstractSystem end
47134
abstract type AbstractODESystem <: AbstractSystem end
48135

49-
Base.promote_rule(::Type{<:Number},::Type{<:Expression}) = Expression
50-
Base.zero(::Type{<:Expression}) = Constant(0)
51-
Base.zero(::Expression) = Constant(0)
52-
Base.one(::Type{<:Expression}) = Constant(1)
53-
Base.one(::Expression) = Constant(1)
54-
Base.oneunit(::Expression) = Constant(1)
55-
Base.oneunit(::Type{<:Expression}) = Constant(1)
56-
57136
"""
58137
$(TYPEDSIGNATURES)
59138
@@ -77,28 +156,15 @@ function parameters end
77156

78157
include("variables.jl")
79158
include("context_dsl.jl")
80-
include("operations.jl")
81159
include("differentials.jl")
82160

83-
function Base.convert(::Type{Variable},x::Operation)
84-
if x.op isa Variable
85-
x.op
86-
elseif x.op isa Differential
87-
var = x.args[1].op
88-
rename(var,Symbol(var.name,,x.op.x))
89-
else
90-
throw(error("This Operation is not a Variable"))
91-
end
92-
end
93-
94161
include("equations.jl")
95-
include("function_registration.jl")
96-
include("simplify.jl")
97162
include("utils.jl")
98163
include("linearity.jl")
99164
include("solve.jl")
100165
include("direct.jl")
101166
include("domains.jl")
167+
include("register_function.jl")
102168

103169
include("systems/abstractsystem.jl")
104170

@@ -145,8 +211,8 @@ export Reaction, ReactionSystem, ismassaction, oderatelaw, jumpratelaw
145211
export Differential, expand_derivatives, @derivatives
146212
export IntervalDomain, ProductDomain, , CircleDomain
147213
export Equation, ConstrainedEquation
148-
export Operation, Expression, Variable
149-
export independent_variable, states, controls, parameters, equations, pins, observed
214+
export Term, Sym
215+
export independent_variable, states, parameters, equations, controls, pins, observed
150216

151217
export calculate_jacobian, generate_jacobian, generate_function
152218
export calculate_tgrad, generate_tgrad
@@ -160,7 +226,7 @@ export BipartiteGraph, equation_dependencies, variable_dependencies
160226
export eqeq_dependencies, varvar_dependencies
161227
export asgraph, asdigraph
162228

163-
export simplified_expr, rename, get_variables
229+
export toexpr, rename, get_variables
164230
export simplify, substitute
165231
export build_function
166232
export @register

0 commit comments

Comments
 (0)