Skip to content

Commit afc4689

Browse files
committed
add test file
1 parent 4142923 commit afc4689

File tree

6 files changed

+44
-16
lines changed

6 files changed

+44
-16
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ include("systems/diffeqs/modelingtoolkitize.jl")
164164
include("systems/diffeqs/basic_transformations.jl")
165165

166166
include("systems/discrete_system/discrete_system.jl")
167+
include("systems/discrete_system/implicit_discrete_system.jl")
167168

168169
include("systems/jumps/jumpsystem.jl")
169170

@@ -229,6 +230,7 @@ export DAEFunctionExpr, DAEProblemExpr
229230
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
230231
export SystemStructure
231232
export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
233+
export ImplicitDiscreteSystem, ImplicitDiscreteProblem, ImplicitDiscreteFunction, ImplicitDiscreteFunctionExpr
232234
export JumpSystem
233235
export ODEProblem, SDEProblem
234236
export NonlinearFunction, NonlinearFunctionExpr

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ function DiscreteSystem(eqs, iv; kwargs...)
233233
push!(new_ps, p)
234234
end
235235
end
236+
@show allunknowns
236237
return DiscreteSystem(eqs, iv,
237238
collect(allunknowns), collect(new_ps); kwargs...)
238239
end

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@ using ModelingToolkit: t_nounits as t
1010
@parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1
1111
@variables x(t)=1.0 y(t)=0.0 z(t)=0.0
1212
k = ShiftIndex(t)
13-
eqs = [x(k+1) ~ σ*(y-x),
14-
y(k+1) ~ x*(ρ-z)-y,
15-
z(k+1) ~ x*y - β*z]
16-
@named de = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or
17-
@named de = ImplicitDiscreteSystem(eqs)
13+
eqs = [x ~ σ*(y(k-1)-x),
14+
y ~ x*(ρ-z(k-1))-y,
15+
z ~ x(k-1)*y - β*z]
16+
@named ide = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0))
1817
```
1918
"""
2019
struct ImplicitDiscreteSystem <: AbstractTimeDependentSystem
@@ -136,6 +135,7 @@ end
136135

137136
"""
138137
$(TYPEDSIGNATURES)
138+
139139
Constructs a ImplicitDiscreteSystem.
140140
"""
141141
function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
@@ -170,6 +170,8 @@ function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
170170
:ImplicitDiscreteSystem, force = true)
171171
end
172172

173+
# Copy equations to canonical form, but do not touch array expressions
174+
eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs]
173175
defaults = Dict{Any, Any}(todict(defaults))
174176
guesses = Dict{Any, Any}(todict(guesses))
175177
var_to_name = Dict()
@@ -236,6 +238,8 @@ function ImplicitDiscreteSystem(eqs, iv; kwargs...)
236238
return ImplicitDiscreteSystem(eqs, iv,
237239
collect(allunknowns), collect(new_ps); kwargs...)
238240
end
241+
# basically at every timestep it should build a nonlinear solve
242+
# Previous timesteps should be treated as parameters? is this right?
239243

240244
function flatten(sys::ImplicitDiscreteSystem, noeqs = false)
241245
systems = get_systems(sys)
@@ -259,10 +263,25 @@ end
259263

260264
function generate_function(
261265
sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
262-
exprs = [eq.rhs for eq in equations(sys)]
263-
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
264-
wrap_parameter_dependencies(sys, false)
265-
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
266+
if !iscomplete(sys)
267+
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
268+
end
269+
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
270+
isscalar = !(exprs isa AbstractArray)
271+
pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
272+
if postprocess_fbody === nothing
273+
postprocess_fbody = pre
274+
end
275+
if states === nothing
276+
states = sol_states
277+
end
278+
exprs = [eq.lhs - eq.rhs for eq in equations(sys)]
279+
u = map(Shift(iv, -1), dvs)
280+
u_next = dvs
281+
282+
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘ wrap_parameter_dependencies(sys, false)
283+
284+
build_function(exprs, u_next, u, p..., get_iv(sys))
266285
end
267286

268287
function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
@@ -311,7 +330,7 @@ function SciMLBase.ImplicitDiscreteProblem(
311330
f, u0, p = process_SciMLProblem(
312331
ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
313332
u0 = f(u0, p, tspan[1])
314-
ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
333+
NonlinearProblem(f, u0, tspan, p; kwargs...)
315334
end
316335

317336
function SciMLBase.ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...; kwargs...)
@@ -337,14 +356,15 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
337356
eval_module = @__MODULE__,
338357
analytic = nothing,
339358
kwargs...) where {iip, specialize}
359+
340360
if !iscomplete(sys)
341361
error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
342362
end
343363
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
344364
expression_module = eval_module, kwargs...)
345365
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
346-
f(u, p, t) = f_oop(u, p, t)
347-
f(du, u, p, t) = f_iip(du, u, p, t)
366+
f(u_next, u, p, t) = f_oop(u_next, u, p, t)
367+
f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t)
348368

349369
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
350370
if u0 === nothing || p === nothing || t === nothing
@@ -379,8 +399,8 @@ struct ImplicitDiscreteFunctionClosure{O, I} <: Function
379399
f_oop::O
380400
f_iip::I
381401
end
382-
(f::ImplicitDiscreteFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
383-
(f::ImplicitDiscreteFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
402+
(f::ImplicitDiscreteFunctionClosure)(u_next, u, p, t) = f.f_oop(u_next, u, p, t)
403+
(f::ImplicitDiscreteFunctionClosure)(resid, u_next, u, p, t) = f.f_iip(resid, u_next, u, p, t)
384404

385405
function ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = unknowns(sys),
386406
ps = parameters(sys), u0 = nothing;

src/systems/systemstructure.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ function TearingState(sys; quick_cancel = false, check = true)
432432
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
433433
complete(graph), nothing, var_types, sys isa DiscreteSystem),
434434
Any[])
435-
if sys isa DiscreteSystem
435+
if sys isa DiscreteSystem || sys isa ImplicitDiscreteSystem
436436
ts = shift_discrete_system(ts)
437437
end
438438
return ts
@@ -456,6 +456,8 @@ function lower_order_var(dervar, t)
456456
diffvar
457457
end
458458

459+
"""
460+
"""
459461
function shift_discrete_system(ts::TearingState)
460462
@unpack fullvars, sys = ts
461463
discvars = OrderedSet()

test/implicit_discrete_system.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
#init

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ end
7979
@safetestset "Variable Utils Test" include("variable_utils.jl")
8080
@safetestset "Variable Metadata Test" include("test_variable_metadata.jl")
8181
@safetestset "OptimizationSystem Test" include("optimizationsystem.jl")
82-
@safetestset "Discrete System" include("discrete_system.jl")
82+
@safetestset "DiscreteSystem Test" include("discrete_system.jl")
83+
@safetestset "ImplicitDiscreteSystem Test" include("implicit_discrete_system.jl")
8384
@safetestset "SteadyStateSystem Test" include("steadystatesystems.jl")
8485
@safetestset "SDESystem Test" include("sdesystem.jl")
8586
@safetestset "DDESystem Test" include("dde.jl")

0 commit comments

Comments
 (0)