Skip to content

Commit a11cfa6

Browse files
committed
handle time dependent parameters -- not just names
1 parent 4f2ab4c commit a11cfa6

File tree

6 files changed

+19
-22
lines changed

6 files changed

+19
-22
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import IfElse
1919
using RecursiveArrayTools
2020

2121
import SymbolicUtils
22-
import SymbolicUtils: Term, Sym, to_symbolic, FnType, @rule, Rewriters, substitute
22+
import SymbolicUtils: Term, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
2323

2424
using LinearAlgebra: LU, BlasInt
2525

@@ -50,7 +50,7 @@ end
5050
const show_numwrap = Ref(false)
5151

5252
Num(x::Num) = x # ideally this should never be called
53-
(n::Num)(args...) = value(n)(map(value,args)...)
53+
(n::Num)(args...) = Num(value(n)(map(value,args)...))
5454
value(x) = x
5555
value(x::Num) = x.val
5656

@@ -105,7 +105,6 @@ end
105105
@num_method Base.isless isless(value(a), value(b))
106106
@num_method Base.isequal isequal(value(a), value(b)) (Number, Symbolic)
107107
@num_method Base.:(==) value(a) == value(b) (Number,)
108-
Base.real(x::Num) = Num(real(value(x)))
109108

110109
Base.hash(x::Num, h::UInt) = hash(value(x), h)
111110

src/differentials.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ function expand_derivatives(n::Num, simplify=true; occurances=nothing)
137137
Num(expand_derivatives(value(n), simplify; occurances=occurances))
138138
end
139139

140-
_iszero(x::Number) = iszero(x)
141-
_isone(x::Number) = isone(x)
142140
_iszero(x) = false
143141
_isone(x) = false
144142

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
function calculate_tgrad(sys::AbstractODESystem;
22
simplify=true)
33
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
4+
5+
# We need to remove explicit time dependence on the state because when we
6+
# have `u(t) * t` we want to have the tgrad to be `u(t)` instead of `u'(t) *
7+
# t + u(t)`.
48
rhs = [detime_dvs(eq.rhs) for eq equations(sys)]
59
iv = sys.iv
6-
for r in rhs
7-
@show r
8-
@show expand_derivatives(Differential(iv)(r))
9-
end
1010
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
11-
tgrad = retime_dvs.(notime_tgrad,(states(sys),),iv)
1211
if simplify
13-
tgrad = ModelingToolkit.simplify.(tgrad)
12+
tgrad = ModelingToolkit.simplify.(notime_tgrad)
1413
end
1514
sys.tgrad[] = tgrad
1615
return tgrad

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function ODESystem(eqs, iv=nothing; kwargs...)
102102
# NOTE: this assumes that the order of algebric equations doesn't matter
103103
diffvars = OrderedSet()
104104
allstates = OrderedSet()
105-
ps = OrderedSet{Sym}()
105+
ps = OrderedSet()
106106
# reorder equations such that it is in the form of `diffeq, algeeq`
107107
diffeq = Equation[]
108108
algeeq = Equation[]

src/utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,21 @@ end
3333

3434
function detime_dvs(op::Term)
3535
if op.op isa Sym
36-
op.op
36+
Sym{Number}(nameof(op.op))
3737
else
3838
Term(op.op,detime_dvs.(op.args))
3939
end
4040
end
4141
detime_dvs(op) = op
4242

43-
function retime_dvs(op::Operation,dvs,iv)
44-
if op.op isa Variable && op.op dvs
45-
Operation(Variable{vartype(op.op)}(op.op.name),Expression[iv])
46-
else
47-
Operation(op.op,retime_dvs.(op.args,(dvs,),iv))
48-
end
43+
function retime_dvs(op::Sym,dvs,iv)
44+
Sym{FnType{Tuple{symtype(iv)}, Number}}(nameof(op))(iv)
45+
end
46+
47+
function retime_dvs(op::Term, dvs, iv)
48+
similarterm(op, op.op, retime_dvs.(op.args,(dvs,),(iv,)))
4949
end
50-
retime_dvs(op::Constant,dvs,iv) = op
50+
retime_dvs(op,dvs,iv) = op
5151

5252
is_constant(::Constant) = true
5353
is_constant(::Any) = false

test/odesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ de = ODESystem(eqs) # This is broken
5858
ModelingToolkit.calculate_tgrad(de)
5959

6060
tgrad_oop, tgrad_iip = eval.(ModelingToolkit.generate_tgrad(de))
61+
6162
@test tgrad_oop(u,p,t) == [0.0,-u[2],0.0]
6263
du = zeros(3)
6364
tgrad_iip(du,u,p,t)
@@ -82,7 +83,7 @@ tgrad_iip(du,u,p,t)
8283
D(y) ~ x*-z)-y,
8384
D(z) ~ x*y - β*z]
8485
de = ODESystem(eqs)
85-
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ, ρ, β))
86+
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ(t-1), ρ, β))
8687
@test begin
8788
f = eval(generate_function(de, [x,y,z], [σ,ρ,β])[2])
8889
du = [0.0,0.0,0.0]
@@ -92,7 +93,7 @@ tgrad_iip(du,u,p,t)
9293

9394
eqs = [D(x) ~ x + 10σ(t-1) + 100σ(t-2) + 1000σ(t^2)]
9495
de = ODESystem(eqs)
95-
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ,))
96+
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ(t-2),σ(t^2), σ(t-1)))
9697
@test begin
9798
f = eval(generate_function(de, [x], [σ])[2])
9899
du = [0.0]

0 commit comments

Comments
 (0)