Skip to content

Commit 4f2ab4c

Browse files
shashiYingboMa
andcommitted
unbreak some stuff
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 73ea3b1 commit 4f2ab4c

File tree

7 files changed

+40
-33
lines changed

7 files changed

+40
-33
lines changed

src/build_function.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
8888
end
8989

9090
# Scalar output
91-
function _build_function(target::JuliaTarget, op::Operation, args...;
91+
function _build_function(target::JuliaTarget, op, args...;
9292
conv = toexpr, expression = Val{true},
9393
checkbounds = false,
9494
linenumbers = true, headerfun=addheader)
@@ -215,7 +215,7 @@ Special Keyword Argumnets:
215215
- `fillzeros`: Whether to perform `fill(out,0)` before the calculations to ensure
216216
safety with `skipzeros`.
217217
"""
218-
function _build_function(target::JuliaTarget, rhss, args...;
218+
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
219219
conv = toexpr, expression = Val{true},
220220
checkbounds = false,
221221
linenumbers = false, multithread=nothing,
@@ -447,17 +447,24 @@ end
447447

448448
vars_to_pairs(args) = vars_to_pairs(args[1],args[2])
449449
function vars_to_pairs(name,vs::AbstractArray)
450-
_vs = convert.(Variable,vs)
451-
names = [Symbol(u) for u _vs]
452-
exs = [:($name[$i]) for (i, u) enumerate(_vs)]
453-
names,exs
450+
vs_names = [term_to_symbol(value(u)) for u vs]
451+
exs = [:($name[$i]) for (i, u) enumerate(vs)]
452+
vs_names,exs
453+
end
454+
455+
function term_to_symbol(t::Term)
456+
if operation(t) isa Sym
457+
s = nameof(operation(t))
458+
@show s
459+
else
460+
error("really?")
461+
end
454462
end
455463

464+
term_to_symbol(s::Sym) = nameof(s)
465+
456466
function vars_to_pairs(name,vs)
457-
_vs = convert(Variable,vs)
458-
names = [Symbol(_vs)]
459-
exs = [name]
460-
names,exs
467+
[term_to_symbol(value(vs))], [name]
461468
end
462469

463470
get_varnumber(varop::Operation,vars::Vector{Operation}) = findfirst(x->isequal(x,varop),vars)
@@ -480,7 +487,7 @@ end
480487

481488
function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
482489
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)
483-
i = findfirst(x->isequal(x isa Variable ? x.name : x.op.name,var_from_nested_derivative(de.lhs)[1].name),varordering)
490+
i = findfirst(x->isequal(x isa Variable ? term_to_symbol(x) : term_to_symbol(x.op),term_to_symbol(var_from_nested_derivative(de.lhs)[1])),varordering)
484491
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
485492
varordering = varordering,
486493
lhsname = lhsname,

src/systems/abstractsystem.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,7 @@ end
245245
function islinear(sys::AbstractSystem)
246246
rhs = [eq.rhs for eq equations(sys)]
247247

248-
iv = sys.iv
249-
dvs = [dv(iv) for dv states(sys)]
250-
251-
all(islinear(r, dvs) for r in rhs)
248+
all(islinear(r, states(sys)) for r in rhs)
252249
end
253250

254251
function pins(sys::AbstractSystem,args...)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@ function calculate_tgrad(sys::AbstractODESystem;
22
simplify=true)
33
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
44
rhs = [detime_dvs(eq.rhs) for eq equations(sys)]
5-
iv = sys.iv()
5+
iv = sys.iv
6+
for r in rhs
7+
@show r
8+
@show expand_derivatives(Differential(iv)(r))
9+
end
610
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
711
tgrad = retime_dvs.(notime_tgrad,(states(sys),),iv)
812
if simplify
913
tgrad = ModelingToolkit.simplify.(tgrad)
1014
end
1115
sys.tgrad[] = tgrad
12-
return @show(tgrad)
16+
return tgrad
1317
end
1418

1519
function calculate_jacobian(sys::AbstractODESystem;
@@ -18,7 +22,7 @@ function calculate_jacobian(sys::AbstractODESystem;
1822
rhs = [eq.rhs for eq equations(sys)]
1923

2024
iv = sys.iv
21-
dvs = [dv(iv) for dv states(sys)]
25+
dvs = states(sys)
2226

2327
if sparse
2428
jac = sparsejacobian(rhs, dvs, simplify=simplify)
@@ -90,7 +94,7 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify=true)
9094
if eq.lhs isa Constant
9195
@assert eq.lhs.value == 0
9296
elseif eq.lhs.op isa Differential
93-
j = findfirst(x->isequal(x.name,var_from_nested_derivative(eq.lhs)[1].name),dvs)
97+
j = findfirst(x->isequal(term_to_symbol(x),term_to_symbol(var_from_nested_derivative(eq.lhs)[1])),dvs)
9498
M[i,j] = 1
9599
else
96100
error("Only semi-explicit constant mass matrices are currently supported. Faulty equation: $eq.")

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
7171
iv′ = value(iv)
7272
dvs′ = value.(dvs)
7373
ps′ = value.(ps)
74-
tgrad = RefValue(Vector{Expression}(undef, 0))
75-
jac = RefValue{Any}(Matrix{Expression}(undef, 0, 0))
76-
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
77-
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
74+
tgrad = RefValue(Vector{Num}(undef, 0))
75+
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
76+
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
77+
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
7878
ODESystem(deqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems)
7979
end
8080

src/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function detime_dvs(op::Operation)
35-
if op.op isa Variable
36-
Operation(Variable{vartype(op.op)}(op.op.name),Expression[])
34+
function detime_dvs(op::Term)
35+
if op.op isa Sym
36+
op.op
3737
else
38-
Operation(op.op,detime_dvs.(op.args))
38+
Term(op.op,detime_dvs.(op.args))
3939
end
4040
end
41-
detime_dvs(op::Constant) = op
41+
detime_dvs(op) = op
4242

4343
function retime_dvs(op::Operation,dvs,iv)
4444
if op.op isa Variable && op.op dvs

test/direct.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Jiip(J2,[1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
9898

9999
# Function building
100100

101-
@parameters σ() ρ() β()
101+
@parameters σ ρ β
102102
@variables x y z
103103
eqs =*(y-x),
104104
x*-z)-y,
@@ -112,7 +112,7 @@ f(out,[1.0,2,3],[1.0,2,3])
112112
@test all(o1 .== out)
113113

114114
function test_worldage()
115-
@parameters σ() ρ() β()
115+
@parameters σ ρ β
116116
@variables x y z
117117
eqs =*(y-x),
118118
x*-z)-y,

test/odesystem.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ generate_function(de)
2424
function test_diffeq_inference(name, sys, iv, dvs, ps)
2525
@testset "ODESystem construction: $name" begin
2626
@test independent_variable(sys) == value(iv)
27-
@show sys.states
28-
@show @which states(sys)
2927
@test Set(states(sys)) == Set(value.(dvs))
3028
@test Set(parameters(sys)) == Set(value.(ps))
3129
end
@@ -52,10 +50,11 @@ u = SVector(1:3...)
5250
p = SVector(4:6...)
5351
@test f(u, p, 0.1) === @SArray [4, 0, -16]
5452

53+
@show y
5554
eqs = [D(x) ~ σ*(y-x),
5655
D(y) ~ x*-z)-y*t,
5756
D(z) ~ x*y - β*z]
58-
de = ODESystem(eqs)
57+
de = ODESystem(eqs) # This is broken
5958
ModelingToolkit.calculate_tgrad(de)
6059

6160
tgrad_oop, tgrad_iip = eval.(ModelingToolkit.generate_tgrad(de))

0 commit comments

Comments
 (0)