Skip to content

Commit 1b957eb

Browse files
shashiYingboMa
andcommitted
Fix state extraction
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 28370ac commit 1b957eb

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

src/systems/abstractsystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ end
157157
renamespace(namespace,name) = Symbol(namespace,:₊,name)
158158

159159
function namespace_variables(sys::AbstractSystem)
160-
[rename(x,renamespace(sys.name,x.name)) for x in states(sys)]
160+
[Sym{symtype(x.op)}(renamespace(sys.name,x.op.name))(x.args...) for x in states(sys)]
161161
end
162162

163163
function namespace_parameters(sys::AbstractSystem)
@@ -190,7 +190,11 @@ end
190190
namespace_expr(O,name,ivname) = O
191191

192192
independent_variable(sys::AbstractSystem) = sys.iv
193-
states(sys::AbstractSystem) = unique(isempty(sys.systems) ? setdiff(sys.states, convert.(Variable,sys.pins)) : [sys.states;reduce(vcat,namespace_variables.(sys.systems))])
193+
function states(sys::AbstractSystem)
194+
unique(isempty(sys.systems) ?
195+
setdiff(sys.states, value.(sys.pins)) :
196+
[sys.states;reduce(vcat,namespace_variables.(sys.systems))])
197+
end
194198
parameters(sys::AbstractSystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
195199
pins(sys::AbstractSystem) = isempty(sys.systems) ? sys.pins : [sys.pins;reduce(vcat,namespace_pins.(sys.systems))]
196200
function observed(sys::AbstractSystem)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,14 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
271271
kwargs...) where iip
272272
dvs = states(sys)
273273
ps = parameters(sys)
274-
u0 = varmap_to_vars(u0map,dvs)
275-
p = varmap_to_vars(parammap,ps)
274+
u0map′ = [lower_varname(value(k), sys.iv) => value(v) for (k, v) in u0map]
275+
parammap′ = [value(k) => value(v) for (k, v) in parammap]
276+
u0 = varmap_to_vars(u0map′,dvs)
277+
if !(parammap isa DiffEqBase.NullParameters)
278+
p = varmap_to_vars(parammap′,ps)
279+
else
280+
p = ps
281+
end
276282
f = ODEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
277283
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
278284
sparse=sparse,eval_expression=eval_expression,kwargs...)
@@ -306,10 +312,17 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
306312
simplify = true,
307313
linenumbers = false, parallel=SerialForm(),
308314
kwargs...) where iip
315+
309316
dvs = states(sys)
310317
ps = parameters(sys)
311-
u0 = varmap_to_vars(u0map,dvs)
312-
p = varmap_to_vars(parammap,ps)
318+
u0map′ = [lower_varname(value(k), sys.iv) => value(v) for (k, v) in u0map]
319+
parammap′ = [value(k) => value(v) for (k, v) in parammap]
320+
u0 = varmap_to_vars(u0map′,dvs)
321+
if !(parammap isa DiffEqBase.NullParameters)
322+
p = varmap_to_vars(parammap′,ps)
323+
else
324+
p = ps
325+
end
313326
f = ODEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
314327
linenumbers=linenumbers,parallel=parallel,
315328
simplify=simplify,

src/systems/diffeqs/first_order_transform.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ function lower_varname(var::Term, idv, order)
55
return Sym{symtype(var.op)}(name)(var.args[1])
66
end
77

8+
function lower_varname(t::Term, iv)
9+
var, order = var_from_nested_derivative(t)
10+
lower_varname(var, iv, order)
11+
end
12+
lower_varname(t::Sym, iv) = t
13+
814
function flatten_differential(O::Term)
915
@assert is_derivative(O) "invalid differential: $O"
1016
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using SafeTestsets, Test
77
@safetestset "Direct Usage Test" begin include("direct.jl") end
88
@safetestset "System Linearity Test" begin include("linearity.jl") end
99
@safetestset "Build Function Test" begin include("build_function.jl") end
10-
#@safetestset "ODESystem Test" begin include("odesystem.jl") end
10+
@safetestset "ODESystem Test" begin include("odesystem.jl") end
1111
@safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end
1212
@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end
1313
@safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end

0 commit comments

Comments
 (0)