diff --git a/docs/pages.jl b/docs/pages.jl index f6c49a0de3..fa09e2b4d9 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -43,7 +43,8 @@ pages = [ "systems/NonlinearSystem.md", "systems/OptimizationSystem.md", "systems/PDESystem.md", - "systems/DiscreteSystem.md"], + "systems/DiscreteSystem.md", + "systems/ImplicitDiscreteSystem.md"], "comparison.md", "internals.md" ] diff --git a/docs/src/systems/DiscreteSystem.md b/docs/src/systems/DiscreteSystem.md index 5ede50c62a..f8a71043ab 100644 --- a/docs/src/systems/DiscreteSystem.md +++ b/docs/src/systems/DiscreteSystem.md @@ -26,3 +26,9 @@ structural_simplify DiscreteProblem(sys::DiscreteSystem, u0map, tspan) DiscreteFunction(sys::DiscreteSystem, args...) ``` + +## Discrete Domain + +```@docs; canonical=false +Shift +``` diff --git a/docs/src/systems/ImplicitDiscreteSystem.md b/docs/src/systems/ImplicitDiscreteSystem.md new file mode 100644 index 0000000000..d69f88f106 --- /dev/null +++ b/docs/src/systems/ImplicitDiscreteSystem.md @@ -0,0 +1,34 @@ +# ImplicitDiscreteSystem + +## System Constructors + +```@docs +ImplicitDiscreteSystem +``` + +## Composition and Accessor Functions + + - `get_eqs(sys)` or `equations(sys)`: The equations that define the implicit discrete system. + - `get_unknowns(sys)` or `unknowns(sys)`: The set of unknowns in the implicit discrete system. + - `get_ps(sys)` or `parameters(sys)`: The parameters of the implicit discrete system. + - `get_iv(sys)`: The independent variable of the implicit discrete system + - `discrete_events(sys)`: The set of discrete events in the implicit discrete system. + +## Transformations + +```@docs; canonical=false +structural_simplify +``` + +## Problem Constructors + +```@docs; canonical=false +ImplicitDiscreteProblem(sys::ImplicitDiscreteSystem, u0map, tspan) +ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...) +``` + +## Discrete Domain + +```@docs; canonical=false +Shift +``` diff --git a/docs/src/tutorials/SampledData.md b/docs/src/tutorials/SampledData.md index a72fd1698b..c700bae5c2 100644 --- a/docs/src/tutorials/SampledData.md +++ b/docs/src/tutorials/SampledData.md @@ -25,8 +25,10 @@ The operators [`Sample`](@ref) and [`Hold`](@ref) are thus providing the interfa The [`ShiftIndex`](@ref) operator is used to refer to past and future values of discrete-time variables. The example below illustrates its use, implementing the discrete-time system ```math -x(k+1) = 0.5x(k) + u(k) -y(k) = x(k) +\begin{align} + x(k+1) &= 0.5x(k) + u(k) \\ + y(k) &= x(k) +\end{align} ``` ```@example clocks @@ -187,3 +189,10 @@ connections = [r ~ sin(t) # reference signal @named cl = ODESystem(connections, t, systems = [f, c, p]) ``` + +```@docs +Sample +Hold +ShiftIndex +Clock +``` diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index e6b8130af3..bbaeea1a40 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -127,6 +127,7 @@ abstract type AbstractTimeIndependentSystem <: AbstractSystem end abstract type AbstractODESystem <: AbstractTimeDependentSystem end abstract type AbstractMultivariateSystem <: AbstractSystem end abstract type AbstractOptimizationSystem <: AbstractTimeIndependentSystem end +abstract type AbstractDiscreteSystem <: AbstractTimeDependentSystem end function independent_variable end @@ -174,6 +175,7 @@ include("systems/diffeqs/modelingtoolkitize.jl") include("systems/diffeqs/basic_transformations.jl") include("systems/discrete_system/discrete_system.jl") +include("systems/discrete_system/implicit_discrete_system.jl") include("systems/jumps/jumpsystem.jl") @@ -235,6 +237,8 @@ export DAEFunctionExpr, DAEProblemExpr export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr export SystemStructure export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr +export ImplicitDiscreteSystem, ImplicitDiscreteProblem, ImplicitDiscreteFunction, + ImplicitDiscreteFunctionExpr export JumpSystem export ODEProblem, SDEProblem export NonlinearFunction, NonlinearFunctionExpr @@ -298,7 +302,7 @@ export debug_system #export ContinuousClock, Discrete, sampletime, input_timedomain, output_timedomain #export has_discrete_domain, has_continuous_domain #export is_discrete_domain, is_continuous_domain, is_hybrid_domain -export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime +export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime, Next, Prev export Clock, SolverStepClock, TimeDomain export MTKParameters, reorder_dimension_by_tunables!, reorder_dimension_by_tunables diff --git a/src/doc b/src/doc new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/src/doc @@ -0,0 +1 @@ + diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index f0124d7f4b..2e600e86e4 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -65,7 +65,7 @@ export torn_system_jacobian_sparsity export full_equations export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask export computed_highest_diff_variables -export shift2term, lower_shift_varname +export shift2term, lower_shift_varname, simplify_shifts, distribute_shift include("utils.jl") include("pantelides.jl") diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 9818bba361..f1cdd7ce9c 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -558,6 +558,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching end total_sub[simplify_shifts(neweq.lhs)] = neweq.rhs + # Substitute unshifted variables x(k), y(k) on RHS of implicit equations + if is_only_discrete(structure) + var_to_diff[iv] === nothing && (total_sub[var] = neweq.rhs) + end push!(diff_eqs, neweq) push!(diffeq_idxs, ieq) push!(diff_vars, diff_to_var[iv]) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index f3cea9c7ba..14628f2958 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -457,11 +457,10 @@ Handle renaming variable names for discrete structural simplification. Three cas """ function lower_shift_varname(var, iv) op = operation(var) - op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t) - if op.steps < 0 + if op isa Shift return shift2term(var) else - return var + return Shift(iv, 0)(var, true) end end @@ -476,10 +475,14 @@ function shift2term(var) backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps - num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁ - ds = join([Char(0x209c), Char(0x208b), num]) - # Char(0x209c) = ₜ # Char(0x208b) = ₋ (subscripted minus) + # Char(0x208a) = ₊ (subscripted plus) + pm = backshift > 0 ? Char(0x208a) : Char(0x208b) + # subscripted number, e.g. ₁ + num = join(Char(0x2080 + d) for d in reverse!(digits(abs(backshift)))) + # Char(0x209c) = ₜ + # ds = ₜ₋₁ + ds = join([Char(0x209c), pm, num]) O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg oldop = operation(O) @@ -499,6 +502,9 @@ function isdoubleshift(var) ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift) end +""" +Simplify multiple shifts: Shift(t, k1)(Shift(t, k2)(x)) becomes Shift(t, k1+k2)(x). +""" function simplify_shifts(var) ModelingToolkit.hasshift(var) || return var var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs) @@ -518,3 +524,45 @@ function simplify_shifts(var) unwrap(var).metadata) end end + +""" +Distribute a shift applied to a whole expression or equation. +Shift(t, 1)(x + y) will become Shift(t, 1)(x) + Shift(t, 1)(y). +Only shifts variables whose independent variable is the same t that appears in the Shift (i.e. constants, time-independent parameters, etc. do not get shifted). +""" +function distribute_shift(var) + var = unwrap(var) + var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs) + + ModelingToolkit.hasshift(var) || return var + shift = operation(var) + shift isa Shift || return var + + shift = operation(var) + expr = only(arguments(var)) + if expr isa Equation + return distribute_shift(shift(expr.lhs)) ~ distribute_shift(shift(expr.rhs)) + end + shiftexpr = _distribute_shift(expr, shift) + return simplify_shifts(shiftexpr) +end + +function _distribute_shift(expr, shift) + if iscall(expr) + op = operation(expr) + args = arguments(expr) + + if ModelingToolkit.isvariable(expr) + (length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) : + (return expr) + elseif op isa Shift + return shift(expr) + else + return maketerm( + typeof(expr), operation(expr), Base.Fix2(_distribute_shift, shift).(args), + unwrap(expr).metadata) + end + else + return expr + end +end diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 8359a3e7de..5c7d77ec83 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -17,7 +17,7 @@ eqs = [x(k+1) ~ σ*(y-x), @named de = DiscreteSystem(eqs) ``` """ -struct DiscreteSystem <: AbstractTimeDependentSystem +struct DiscreteSystem <: AbstractDiscreteSystem """ A tag for the system. If two systems have the same tag, then they are structurally identical. @@ -237,6 +237,8 @@ function DiscreteSystem(eqs, iv; kwargs...) collect(allunknowns), collect(new_ps); kwargs...) end +DiscreteSystem(eq::Equation, args...; kwargs...) = DiscreteSystem([eq], args...; kwargs...) + function flatten(sys::DiscreteSystem, noeqs = false) systems = get_systems(sys) if isempty(systems) diff --git a/src/systems/discrete_system/implicit_discrete_system.jl b/src/systems/discrete_system/implicit_discrete_system.jl new file mode 100644 index 0000000000..327c82c47f --- /dev/null +++ b/src/systems/discrete_system/implicit_discrete_system.jl @@ -0,0 +1,440 @@ +""" +$(TYPEDEF) +An implicit system of difference equations. +# Fields +$(FIELDS) +# Example +``` +using ModelingToolkit +using ModelingToolkit: t_nounits as t +@parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1 +@variables x(t)=1.0 y(t)=0.0 z(t)=0.0 +k = ShiftIndex(t) +eqs = [x ~ σ*(y-x(k-1)), + y ~ x(k-1)*(ρ-z(k-1))-y, + z ~ x(k-1)*y(k-1) - β*z] +@named ide = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) +``` +""" +struct ImplicitDiscreteSystem <: AbstractDiscreteSystem + """ + A tag for the system. If two systems have the same tag, then they are + structurally identical. + """ + tag::UInt + """The difference equations defining the discrete system.""" + eqs::Vector{Equation} + """Independent variable.""" + iv::BasicSymbolic{Real} + """Dependent (state) variables. Must not contain the independent variable.""" + unknowns::Vector + """Parameter variables. Must not contain the independent variable.""" + ps::Vector + """Time span.""" + tspan::Union{NTuple{2, Any}, Nothing} + """Array variables.""" + var_to_name::Any + """Observed states.""" + observed::Vector{Equation} + """ + The name of the system + """ + name::Symbol + """ + A description of the system. + """ + description::String + """ + The internal systems. These are required to have unique names. + """ + systems::Vector{ImplicitDiscreteSystem} + """ + The default values to use when initial conditions and/or + parameters are not supplied in `ImplicitDiscreteProblem`. + """ + defaults::Dict + """ + The guesses to use as the initial conditions for the + initialization system. + """ + guesses::Dict + """ + The system for performing the initialization. + """ + initializesystem::Union{Nothing, NonlinearSystem} + """ + Extra equations to be enforced during the initialization sequence. + """ + initialization_eqs::Vector{Equation} + """ + Inject assignment statements before the evaluation of the RHS function. + """ + preface::Any + """ + Type of the system. + """ + connector_type::Any + """ + Topologically sorted parameter dependency equations, where all symbols are parameters and + the LHS is a single parameter. + """ + parameter_dependencies::Vector{Equation} + """ + Metadata for the system, to be used by downstream packages. + """ + metadata::Any + """ + Metadata for MTK GUI. + """ + gui_metadata::Union{Nothing, GUIMetadata} + """ + Cache for intermediate tearing state. + """ + tearing_state::Any + """ + Substitutions generated by tearing. + """ + substitutions::Any + """ + If a model `sys` is complete, then `sys.x` no longer performs namespacing. + """ + complete::Bool + """ + Cached data for fast symbolic indexing. + """ + index_cache::Union{Nothing, IndexCache} + """ + The hierarchical parent system before simplification. + """ + parent::Any + isscheduled::Bool + + function ImplicitDiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, + observed, name, description, systems, defaults, guesses, initializesystem, + initialization_eqs, preface, connector_type, parameter_dependencies = Equation[], + metadata = nothing, gui_metadata = nothing, + tearing_state = nothing, substitutions = nothing, + complete = false, index_cache = nothing, parent = nothing, + isscheduled = false; + checks::Union{Bool, Int} = true) + if checks == true || (checks & CheckComponents) > 0 + check_independent_variables([iv]) + check_variables(dvs, iv) + check_parameters(ps, iv) + end + if checks == true || (checks & CheckUnits) > 0 + u = __get_unit_type(dvs, ps, iv) + check_units(u, discreteEqs) + end + new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, observed, name, description, + systems, defaults, guesses, initializesystem, initialization_eqs, + preface, connector_type, parameter_dependencies, metadata, gui_metadata, + tearing_state, substitutions, complete, index_cache, parent, isscheduled) + end +end + +""" + $(TYPEDSIGNATURES) + +Constructs a ImplicitDiscreteSystem. +""" +function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; + observed = Num[], + systems = ImplicitDiscreteSystem[], + tspan = nothing, + name = nothing, + description = "", + default_u0 = Dict(), + default_p = Dict(), + guesses = Dict(), + initializesystem = nothing, + initialization_eqs = Equation[], + defaults = _merge(Dict(default_u0), Dict(default_p)), + preface = nothing, + connector_type = nothing, + parameter_dependencies = Equation[], + metadata = nothing, + gui_metadata = nothing, + kwargs...) + name === nothing && + throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) + iv′ = value(iv) + dvs′ = value.(dvs) + ps′ = value.(ps) + if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs) + error("Equations in a `ImplicitDiscreteSystem` can only have `Shift` operators.") + end + if !(isempty(default_u0) && isempty(default_p)) + Base.depwarn( + "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", + :ImplicitDiscreteSystem, force = true) + end + + # Copy equations to canonical form, but do not touch array expressions + eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs] + defaults = Dict{Any, Any}(todict(defaults)) + guesses = Dict{Any, Any}(todict(guesses)) + var_to_name = Dict() + process_variables!(var_to_name, defaults, guesses, dvs′) + process_variables!(var_to_name, defaults, guesses, ps′) + process_variables!( + var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies]) + process_variables!( + var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies]) + defaults = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(defaults) if v !== nothing) + guesses = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(guesses) if v !== nothing) + + isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) + + sysnames = nameof.(systems) + if length(unique(sysnames)) != length(sysnames) + throw(ArgumentError("System names must be unique.")) + end + ImplicitDiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), + eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems, + defaults, guesses, initializesystem, initialization_eqs, preface, connector_type, + parameter_dependencies, metadata, gui_metadata, kwargs...) +end + +function ImplicitDiscreteSystem(eqs, iv; kwargs...) + eqs = collect(eqs) + diffvars = OrderedSet() + allunknowns = OrderedSet() + ps = OrderedSet() + iv = value(iv) + for eq in eqs + collect_vars!(allunknowns, ps, eq, iv; op = Shift) + if iscall(eq.lhs) && operation(eq.lhs) isa Shift + isequal(iv, operation(eq.lhs).t) || + throw(ArgumentError("An ImplicitDiscreteSystem can only have one independent variable.")) + eq.lhs in diffvars && + throw(ArgumentError("The shift variable $(eq.lhs) is not unique in the system of equations.")) + push!(diffvars, eq.lhs) + end + end + for eq in get(kwargs, :parameter_dependencies, Equation[]) + if eq isa Pair + collect_vars!(allunknowns, ps, eq, iv) + else + collect_vars!(allunknowns, ps, eq, iv) + end + end + new_ps = OrderedSet() + for p in ps + if iscall(p) && operation(p) === getindex + par = arguments(p)[begin] + if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && + all(par[i] in ps for i in eachindex(par)) + push!(new_ps, par) + else + push!(new_ps, p) + end + else + push!(new_ps, p) + end + end + return ImplicitDiscreteSystem(eqs, iv, + collect(allunknowns), collect(new_ps); kwargs...) +end + +function ImplicitDiscreteSystem(eq::Equation, args...; kwargs...) + ImplicitDiscreteSystem([eq], args...; kwargs...) +end + +function flatten(sys::ImplicitDiscreteSystem, noeqs = false) + systems = get_systems(sys) + if isempty(systems) + return sys + else + return ImplicitDiscreteSystem(noeqs ? Equation[] : equations(sys), + get_iv(sys), + unknowns(sys), + parameters(sys), + observed = observed(sys), + defaults = defaults(sys), + guesses = guesses(sys), + initialization_eqs = initialization_equations(sys), + name = nameof(sys), + description = description(sys), + metadata = get_metadata(sys), + checks = false) + end +end + +function generate_function( + sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...) + iv = get_iv(sys) + # Algebraic equations get shifted forward 1, to match with differential equations + exprs = map(equations(sys)) do eq + _iszero(eq.lhs) ? distribute_shift(Shift(iv, 1)(eq.rhs)) : (eq.rhs - eq.lhs) + end + + # Handle observables in algebraic equations, since they are shifted + obs = observed(sys) + shifted_obs = Symbolics.Equation[distribute_shift(Shift(iv, 1)(eq)) for eq in obs] + obsidxs = observed_equations_used_by(sys, exprs; obs = shifted_obs) + extra_assignments = [Assignment(shifted_obs[i].lhs, shifted_obs[i].rhs) + for i in obsidxs] + + u_next = map(Shift(iv, 1), dvs) + u = dvs + build_function_wrapper( + sys, exprs, u_next, u, ps..., iv; p_start = 3, extra_assignments, kwargs...) +end + +function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs) + iv = get_iv(sys) + updated = AnyDict() + for k in collect(keys(u0map)) + v = u0map[k] + if !((op = operation(k)) isa Shift) + isnothing(getunshifted(k)) && + error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).") + + updated[k] = v + elseif op.steps > 0 + error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).") + else + updated[k] = v + end + end + for var in unknowns(sys) + op = operation(var) + root = getunshifted(var) + shift = getshift(var) + isnothing(root) && continue + (haskey(updated, Shift(iv, shift)(root)) || haskey(updated, var)) && continue + haskey(defs, root) || error("Initial condition for $var not provided.") + updated[var] = defs[root] + end + return updated +end + +""" + $(TYPEDSIGNATURES) +Generates an ImplicitDiscreteProblem from an ImplicitDiscreteSystem. +""" +function SciMLBase.ImplicitDiscreteProblem( + sys::ImplicitDiscreteSystem, u0map = [], tspan = get_tspan(sys), + parammap = SciMLBase.NullParameters(); + eval_module = @__MODULE__, + eval_expression = false, + use_union = false, + kwargs... +) + if !iscomplete(sys) + error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`.") + end + dvs = unknowns(sys) + ps = parameters(sys) + eqs = equations(sys) + iv = get_iv(sys) + + u0map = to_varmap(u0map, dvs) + u0map = shift_u0map_forward(sys, u0map, defaults(sys)) + @show u0map + f, u0, p = process_SciMLProblem( + ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...) + @show u0 + + kwargs = filter_kwargs(kwargs) + ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...) +end + +function SciMLBase.ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...; kwargs...) + ImplicitDiscreteFunction{true}(sys, args...; kwargs...) +end + +function SciMLBase.ImplicitDiscreteFunction{true}( + sys::ImplicitDiscreteSystem, args...; kwargs...) + ImplicitDiscreteFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) +end + +function SciMLBase.ImplicitDiscreteFunction{false}( + sys::ImplicitDiscreteSystem, args...; kwargs...) + ImplicitDiscreteFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) +end +function SciMLBase.ImplicitDiscreteFunction{iip, specialize}( + sys::ImplicitDiscreteSystem, + dvs = unknowns(sys), + ps = parameters(sys), + u0 = nothing; + version = nothing, + p = nothing, + t = nothing, + eval_expression = false, + eval_module = @__MODULE__, + analytic = nothing, + kwargs...) where {iip, specialize} + if !iscomplete(sys) + error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`") + end + f_gen = generate_function(sys, dvs, ps; expression = Val{true}, + expression_module = eval_module, kwargs...) + f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) + f(u_next, u, p, t) = f_oop(u_next, u, p, t) + f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t) + + if specialize === SciMLBase.FunctionWrapperSpecialize && iip + if u0 === nothing || p === nothing || t === nothing + error("u0, p, and t must be specified for FunctionWrapperSpecialize on ImplicitDiscreteFunction.") + end + f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t)) + end + + observedfun = ObservedFunctionCache( + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + + ImplicitDiscreteFunction{iip, specialize}(f; + sys = sys, + observed = observedfun, + analytic = analytic, + kwargs...) +end + +""" +```julia +ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = states(sys), + ps = parameters(sys); + version = nothing, + kwargs...) where {iip} +``` + +Create a Julia expression for an `ImplicitDiscreteFunction` from the [`ImplicitDiscreteSystem`](@ref). +The arguments `dvs` and `ps` are used to set the order of the dependent +variable and parameter vectors, respectively. +""" +struct ImplicitDiscreteFunctionExpr{iip} end +struct ImplicitDiscreteFunctionClosure{O, I} <: Function + f_oop::O + f_iip::I +end +(f::ImplicitDiscreteFunctionClosure)(u_next, u, p, t) = f.f_oop(u_next, u, p, t) +function (f::ImplicitDiscreteFunctionClosure)(resid, u_next, u, p, t) + f.f_iip(resid, u_next, u, p, t) +end + +function ImplicitDiscreteFunctionExpr{iip}( + sys::ImplicitDiscreteSystem, dvs = unknowns(sys), + ps = parameters(sys), u0 = nothing; + version = nothing, p = nothing, + linenumbers = false, + simplify = false, + kwargs...) where {iip} + f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...) + + fsym = gensym(:f) + _f = :($fsym = $ImplicitDiscreteFunctionClosure($f_oop, $f_iip)) + + ex = quote + $_f + ImplicitDiscreteFunction{$iip}($fsym) + end + !linenumbers ? Base.remove_linenums!(ex) : ex +end + +function ImplicitDiscreteFunctionExpr(sys::ImplicitDiscreteSystem, args...; kwargs...) + ImplicitDiscreteFunctionExpr{true}(sys, args...; kwargs...) +end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 6151ffa515..9664e38cb6 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -42,8 +42,8 @@ function structural_simplify( if newsys isa DiscreteSystem && any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys)) error(""" - Encountered algebraic equations when simplifying discrete system. This is \ - not yet supported. + Encountered algebraic equations when simplifying discrete system. Please construct \ + an ImplicitDiscreteSystem instead. """) end for pass in additional_passes diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 0b755bd22d..c0c4a5ff4d 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -438,7 +438,7 @@ function TearingState(sys; quick_cancel = false, check = true) ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), - complete(graph), nothing, var_types, sys isa DiscreteSystem), + complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem), Any[]) if sys isa DiscreteSystem ts = shift_discrete_system(ts) diff --git a/test/discrete_system.jl b/test/discrete_system.jl index eea0ffc36b..874d045aa9 100644 --- a/test/discrete_system.jl +++ b/test/discrete_system.jl @@ -257,7 +257,7 @@ end @variables x(t) y(t) k = ShiftIndex(t) @named sys = DiscreteSystem([x ~ x^2 + y^2, y ~ x(k - 1) + y(k - 1)], t) -@test_throws ["algebraic equations", "not yet supported"] structural_simplify(sys) +@test_throws ["algebraic equations", "ImplicitDiscreteSystem"] structural_simplify(sys) @testset "Passing `nothing` to `u0`" begin @variables x(t) = 1 diff --git a/test/implicit_discrete_system.jl b/test/implicit_discrete_system.jl new file mode 100644 index 0000000000..932b6c6981 --- /dev/null +++ b/test/implicit_discrete_system.jl @@ -0,0 +1,75 @@ +using ModelingToolkit, Test +using ModelingToolkit: t_nounits as t +using StableRNGs + +k = ShiftIndex(t) +rng = StableRNG(22525) + +@testset "Correct ImplicitDiscreteFunction" begin + @variables x(t) = 1 + @mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k) * x(k - 1) - 3], t) + tspan = (0, 10) + + # u[2] - u_next[1] + # -3 - u_next[2] + u_next[2]*u_next[1] + f = ImplicitDiscreteFunction(sys) + u_next = [3.0, 1.5] + @test f(u_next, [2.0, 3.0], [], t) ≈ [0.0, 0.0] + u_next = [0.0, 0.0] + @test f(u_next, [2.0, 3.0], [], t) ≈ [3.0, -3.0] + + resid = rand(2) + f(resid, u_next, [2.0, 3.0], [], t) + @test resid ≈ [3.0, -3.0] + + prob = ImplicitDiscreteProblem(sys, [x(k - 1) => 3.0], tspan) + @test prob.u0 == [3.0, 1.0] + prob = ImplicitDiscreteProblem(sys, [], tspan) + @test prob.u0 == [1.0, 1.0] + @variables x(t) + @mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k) * x(k - 1) - 3], t) + @test_throws ErrorException prob=ImplicitDiscreteProblem(sys, [], tspan) +end + +@testset "System with algebraic equations" begin + @variables x(t) y(t) + eqs = [x(k) ~ x(k - 1) + x(k - 2), + x^2 ~ 1 - y^2] + @mtkbuild sys = ImplicitDiscreteSystem(eqs, t) + f = ImplicitDiscreteFunction(sys) + + function correct_f(u_next, u, p, t) + [u[2] - u_next[1], + u[1] + u[2] - u_next[2], + 1 - (u_next[1] + u_next[2])^2 - u_next[3]^2] + end + + for _ in 1:10 + u_next = rand(rng, 3) + u = rand(rng, 3) + @test correct_f(u_next, u, [], 0.0) ≈ f(u_next, u, [], 0.0) + end + + # Initialization is satisfied. + prob = ImplicitDiscreteProblem( + sys, [x(k - 1) => 0.3, x(k - 2) => 0.4], (0, 10), guesses = [y => 1]) + @test length(equations(prob.f.initialization_data.initializeprob.f.sys)) == 1 +end + +@testset "Handle observables in function codegen" begin + # Observable appears in differential equation + @variables x(t) y(t) z(t) + eqs = [x(k) ~ x(k - 1) + x(k - 2), + y(k) ~ x(k) + x(k - 2) * z(k - 1), + x + y + z ~ 2] + @mtkbuild sys = ImplicitDiscreteSystem(eqs, t) + @test length(unknowns(sys)) == length(equations(sys)) == 3 + @test occursin("var\"y(t)\"", string(ImplicitDiscreteFunctionExpr(sys))) + + # Shifted observable that appears in algebraic equation is properly handled. + eqs = [z(k) ~ x(k) + sin(x(k)), + y(k) ~ x(k - 1) + x(k - 2), + z(k) * x(k) ~ 3] + @mtkbuild sys = ImplicitDiscreteSystem(eqs, t) + @test occursin("var\"Shift(t, 1)(z(t))\"", string(ImplicitDiscreteFunctionExpr(sys))) +end diff --git a/test/pde.jl b/test/pdesystem.jl similarity index 100% rename from test/pde.jl rename to test/pdesystem.jl diff --git a/test/runtests.jl b/test/runtests.jl index 0ae66d3755..e0ef4b8640 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,12 +80,13 @@ end @safetestset "Variable Metadata Test" include("test_variable_metadata.jl") @safetestset "OptimizationSystem Test" include("optimizationsystem.jl") @safetestset "Discrete System" include("discrete_system.jl") + @safetestset "Implicit Discrete System" include("implicit_discrete_system.jl") @safetestset "SteadyStateSystem Test" include("steadystatesystems.jl") @safetestset "SDESystem Test" include("sdesystem.jl") @safetestset "DDESystem Test" include("dde.jl") @safetestset "NonlinearSystem Test" include("nonlinearsystem.jl") @safetestset "SCCNonlinearProblem Test" include("scc_nonlinear_problem.jl") - @safetestset "PDE Construction Test" include("pde.jl") + @safetestset "PDE Construction Test" include("pdesystem.jl") @safetestset "JumpSystem Test" include("jumpsystem.jl") @safetestset "BVProblem Test" include("bvproblem.jl") @safetestset "print_tree" include("print_tree.jl") diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 24cfb98d45..9bf239a574 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -5,6 +5,7 @@ using SparseArrays using UnPack using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm using Symbolics: unwrap +const ST = StructuralTransformations # Define some variables @parameters L g @@ -163,6 +164,29 @@ end @test value[] == 1 end +@testset "Distribute shifts" begin + @variables x(t) y(t) z(t) + @parameters a b c + k = ShiftIndex(t) + + # Expand shifts + @test isequal( + ST.distribute_shift(Shift(t, -1)(x + y)), Shift(t, -1)(x) + Shift(t, -1)(y)) + + expr = a * Shift(t, -2)(x) + Shift(t, 2)(y) + b + @test isequal(ST.simplify_shifts(ST.distribute_shift(Shift(t, 2)(expr))), + a * x + Shift(t, 4)(y) + b) + @test isequal(ST.distribute_shift(Shift(t, 2)(exp(z))), exp(Shift(t, 2)(z))) + @test isequal(ST.distribute_shift(Shift(t, 2)(exp(a) + b)), exp(a) + b) + + expr = a^x - log(b * y) + z * x + @test isequal(ST.distribute_shift(Shift(t, -3)(expr)), + a^(Shift(t, -3)(x)) - log(b * Shift(t, -3)(y)) + Shift(t, -3)(z) * Shift(t, -3)(x)) + + expr = x(k + 1) ~ x + x(k - 1) + @test isequal(ST.distribute_shift(Shift(t, -1)(expr)), x ~ x(k - 1) + x(k - 2)) +end + @testset "`map_variables_to_equations`" begin @testset "Not supported for systems without `.tearing_state`" begin @variables x