Skip to content

Commit fff670e

Browse files
authored
Merge pull request #2020 from SciML/myb/uni
Unityper update
2 parents 9dfdec3 + 898e89e commit fff670e

15 files changed

+108
-109
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ SimpleNonlinearSolve = "0.1.0"
7878
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7979
StaticArrays = "0.10, 0.11, 0.12, 1.0"
8080
SymbolicIndexingInterface = "0.1, 0.2"
81-
SymbolicUtils = "0.19"
82-
Symbolics = "4.9"
81+
SymbolicUtils = "1.0"
82+
Symbolics = "5.0"
8383
UnPack = "0.1, 1.0"
8484
Unitful = "1.1"
8585
julia = "1.6"

src/ModelingToolkit.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ import SymbolicIndexingInterface: independent_variables, states, parameters
3939
export independent_variables, states, parameters
4040
import SymbolicUtils
4141
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
42-
Symbolic, Term, Add, Mul, Pow, Sym, FnType,
43-
@rule, Rewriters, substitute, metadata
42+
Symbolic, isadd, ismul, ispow, issym, FnType,
43+
@rule, Rewriters, substitute, metadata, BasicSymbolic,
44+
Sym, Term
4445
using SymbolicUtils.Code
4546
import SymbolicUtils.Code: toexpr
4647
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint

src/clock.jl

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,14 @@ const InferredDomain = Union{Inferred, InferredDiscrete}
1313
Symbolics.option_to_metadata_type(::Val{:timedomain}) = TimeDomain
1414

1515
"""
16-
is_continuous_domain(x::Sym)
17-
18-
Determine if variable `x` is a continuous-time variable.
19-
"""
20-
is_continuous_domain(x::Sym) = getmetadata(x, TimeDomain, false) isa Continuous
21-
22-
"""
23-
is_discrete_domain(x::Sym)
24-
25-
Determine if variable `x` is a discrete-time variable.
16+
is_continuous_domain(x)
17+
true if `x` contains only continuous-domain signals.
18+
See also [`has_continuous_domain`](@ref)
2619
"""
27-
is_discrete_domain(x::Sym) = getmetadata(x, TimeDomain, false) isa Discrete
28-
29-
# is_discrete_domain(x::Sym) = isvarkind(Discrete, x)
30-
31-
has_continuous_domain(x::Sym) = is_continuous_domain(x)
32-
has_discrete_domain(x::Sym) = is_discrete_domain(x)
20+
function is_continuous_domain(x)
21+
issym(x) && return getmetadata(x, TimeDomain, false) isa Continuous
22+
!has_discrete_domain(x) && has_continuous_domain(x)
23+
end
3324

3425
function get_time_domain(x)
3526
if istree(x) && operation(x) isa Operator
@@ -41,11 +32,10 @@ end
4132
get_time_domain(x::Num) = get_time_domain(value(x))
4233

4334
"""
44-
has_time_domain(x::Sym)
45-
35+
has_time_domain(x)
4636
Determine if variable `x` has a time-domain attributed to it.
4737
"""
48-
function has_time_domain(x::Union{Sym, Term})
38+
function has_time_domain(x::Symbolic)
4939
# getmetadata(x, Continuous, nothing) !== nothing ||
5040
# getmetadata(x, Discrete, nothing) !== nothing
5141
getmetadata(x, TimeDomain, nothing) !== nothing
@@ -64,15 +54,21 @@ end
6454
true if `x` contains discrete signals (`x` may or may not contain continuous-domain signals). `x` may be an expression or equation.
6555
See also [`is_discrete_domain`](@ref)
6656
"""
67-
has_discrete_domain(x) = hasshift(x) || hassample(x) || hashold(x)
57+
function has_discrete_domain(x)
58+
issym(x) && return is_discrete_domain(x)
59+
hasshift(x) || hassample(x) || hashold(x)
60+
end
6861

6962
"""
7063
has_continuous_domain(x)
7164
7265
true if `x` contains continuous signals (`x` may or may not contain discrete-domain signals). `x` may be an expression or equation.
7366
See also [`is_continuous_domain`](@ref)
7467
"""
75-
has_continuous_domain(x) = hasderiv(x) || hasdiff(x) || hassample(x) || hashold(x)
68+
function has_continuous_domain(x)
69+
issym(x) && return is_continuous_domain(x)
70+
hasderiv(x) || hasdiff(x) || hassample(x) || hashold(x)
71+
end
7672

7773
"""
7874
is_hybrid_domain(x)
@@ -87,15 +83,10 @@ is_hybrid_domain(x) = has_discrete_domain(x) && has_continuous_domain(x)
8783
true if `x` contains only discrete-domain signals.
8884
See also [`has_discrete_domain`](@ref)
8985
"""
90-
is_discrete_domain(x) = has_discrete_domain(x) && !has_continuous_domain(x)
91-
92-
"""
93-
is_continuous_domain(x)
94-
95-
true if `x` contains only continuous-domain signals.
96-
See also [`has_continuous_domain`](@ref)
97-
"""
98-
is_continuous_domain(x) = !has_discrete_domain(x) && has_continuous_domain(x)
86+
function is_discrete_domain(x)
87+
issym(x) && return getmetadata(x, TimeDomain, false) isa Discrete
88+
!has_discrete_domain(x) && has_continuous_domain(x)
89+
end
9990

10091
struct ClockInferenceException <: Exception
10192
msg::Any
@@ -122,3 +113,4 @@ struct Clock <: AbstractClock
122113
end
123114

124115
sampletime(c) = isdefined(c, :dt) ? c.dt : nothing
116+
Base.:(==)(c1::Clock, c2::Clock) = isequal(c1.t, c2.t) && c1.dt == c2.dt

src/constants.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ function isconstant(x)
1111
end
1212

1313
"""
14-
toconstant(s::Sym)
14+
toconstant(s)
1515
1616
Maps the parameter to a constant. The parameter must have a default.
1717
"""
18-
function toconstant(s::Sym)
18+
function toconstant(s)
1919
hasmetadata(s, Symbolics.VariableDefaultValue) ||
2020
throw(ArgumentError("Constant `$(s)` must be assigned a default value."))
2121
setmetadata(s, MTKConstantCtx, true)

src/inputoutput.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ function outputs(sys)
1919
lhss = [eq.lhs for eq in o]
2020
unique([filter(isoutput, states(sys))
2121
filter(isoutput, parameters(sys))
22-
filter(x -> x isa Term && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
23-
filter(x -> x isa Term && isoutput(x), lhss)])
22+
filter(x -> istree(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
23+
filter(x -> istree(x) && isoutput(x), lhss)])
2424
end
2525

2626
"""

src/parameters.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function isparameter(x)
2727
end
2828

2929
"""
30-
toparam(s::Sym)
30+
toparam(s)
3131
3232
Maps the variable to a paramter.
3333
"""
@@ -43,7 +43,7 @@ end
4343
toparam(s::Num) = wrap(toparam(value(s)))
4444

4545
"""
46-
tovar(s::Sym)
46+
tovar(s)
4747
4848
Maps the variable to a state.
4949
"""

src/systems/abstractsystem.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -419,25 +419,27 @@ function renamespace(sys, x)
419419
sys === nothing && return x
420420
x = unwrap(x)
421421
if x isa Symbolic
422+
T = typeof(x)
422423
if istree(x) && operation(x) isa Operator
423-
return similarterm(x, operation(x), Any[renamespace(sys, only(arguments(x)))])
424+
return similarterm(x, operation(x),
425+
Any[renamespace(sys, only(arguments(x)))])::T
424426
end
425427
let scope = getmetadata(x, SymScope, LocalScope())
426428
if scope isa LocalScope
427-
rename(x, renamespace(getname(sys), getname(x)))
429+
rename(x, renamespace(getname(sys), getname(x)))::T
428430
elseif scope isa ParentScope
429-
setmetadata(x, SymScope, scope.parent)
431+
setmetadata(x, SymScope, scope.parent)::T
430432
elseif scope isa DelayParentScope
431433
if scope.N > 0
432434
x = setmetadata(x, SymScope,
433435
DelayParentScope(scope.parent, scope.N - 1))
434-
rename(x, renamespace(getname(sys), getname(x)))
436+
rename(x, renamespace(getname(sys), getname(x)))::T
435437
else
436438
#rename(x, renamespace(getname(sys), getname(x)))
437-
setmetadata(x, SymScope, scope.parent)
439+
setmetadata(x, SymScope, scope.parent)::T
438440
end
439441
else # GlobalScope
440-
x
442+
x::T
441443
end
442444
end
443445
elseif x isa AbstractSystem
@@ -453,9 +455,8 @@ namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))
453455

454456
function namespace_defaults(sys)
455457
defs = defaults(sys)
456-
Dict((isparameter(k) ? parameters(sys, k) : states(sys, k)) => namespace_expr(defs[k],
457-
sys)
458-
for k in keys(defs))
458+
Dict((isparameter(k) ? parameters(sys, k) : states(sys, k)) => namespace_expr(v, sys)
459+
for (k, v) in pairs(defs))
459460
end
460461

461462
function namespace_equations(sys::AbstractSystem)
@@ -484,14 +485,19 @@ function namespace_expr(O, sys, n = nameof(sys))
484485
elseif isvariable(O)
485486
renamespace(n, O)
486487
elseif istree(O)
487-
renamed = map(a -> namespace_expr(a, sys, n), arguments(O))
488+
T = typeof(O)
488489
if symtype(operation(O)) <: FnType
489-
renamespace(n, O)
490+
renamespace(n, O)::T
490491
else
491-
similarterm(O, operation(O), renamed)
492+
renamed = let sys = sys, n = n, T = T
493+
map(a -> namespace_expr(a, sys, n)::Any, arguments(O))
494+
end
495+
similarterm(O, operation(O), renamed)::T
492496
end
493497
elseif O isa Array
494-
map(o -> namespace_expr(o, sys, n), O)
498+
let sys = sys, n = n
499+
map(o -> namespace_expr(o, sys, n), O)
500+
end
495501
else
496502
O
497503
end
@@ -602,7 +608,7 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
602608
#
603609
# This is done by just making `x` the argument of the function.
604610
if istree(x) &&
605-
operation(x) isa Sym &&
611+
issym(operation(x)) &&
606612
!(length(arguments(x)) == 1 && isequal(arguments(x)[1], get_iv(sys)))
607613
return operation(x)
608614
end
@@ -646,7 +652,7 @@ AbstractSysToExpr(sys) = AbstractSysToExpr(sys, states(sys))
646652
function (f::AbstractSysToExpr)(O)
647653
!istree(O) && return toexpr(O)
648654
any(isequal(O), f.states) && return nameof(operation(O)) # variables
649-
if isa(operation(O), Sym)
655+
if issym(operation(O))
650656
return build_expr(:call, Any[nameof(operation(O)); f.(arguments(O))])
651657
end
652658
return build_expr(:call, Any[operation(O); f.(arguments(O))])
@@ -675,7 +681,7 @@ end
675681
function round_trip_expr(t, var2name)
676682
name = get(var2name, t, nothing)
677683
name !== nothing && return name
678-
t isa Sym && return nameof(t)
684+
issym(t) && return nameof(t)
679685
istree(t) || return t
680686
f = round_trip_expr(operation(t), var2name)
681687
args = map(Base.Fix2(round_trip_expr, var2name), arguments(t))

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify = false)
200200
M = zeros(length(eqs), length(eqs))
201201
state2idx = Dict(s => i for (i, s) in enumerate(dvs))
202202
for (i, eq) in enumerate(eqs)
203-
if eq.lhs isa Term && operation(eq.lhs) isa Differential
203+
if istree(eq.lhs) && operation(eq.lhs) isa Differential
204204
st = var_from_nested_derivative(eq.lhs)[1]
205205
j = state2idx[st]
206206
M[i, j] = 1

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct ODESystem <: AbstractODESystem
3131
"""The ODEs defining the system."""
3232
eqs::Vector{Equation}
3333
"""Independent variable."""
34-
iv::Sym
34+
iv::BasicSymbolic{Real}
3535
"""
3636
Dependent (state) variables. Must not contain the independent variable.
3737
@@ -422,7 +422,7 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys))
422422
newsts[i] = s
423423
continue
424424
end
425-
ns = similarterm(s, operation(s), (t,); metadata = SymbolicUtils.metadata(s))
425+
ns = similarterm(s, operation(s), Any[t]; metadata = SymbolicUtils.metadata(s))
426426
newsts[i] = ns
427427
varmap[s] = ns
428428
else

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct SDESystem <: AbstractODESystem
3737
"""The expressions defining the diffusion term."""
3838
noiseeqs::AbstractArray
3939
"""Independent variable."""
40-
iv::Sym
40+
iv::BasicSymbolic{Real}
4141
"""Dependent (state) variables. Must not contain the independent variable."""
4242
states::Vector
4343
"""Parameter variables. Must not contain the independent variable."""

0 commit comments

Comments
 (0)