Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1ff286d
Early work on the new discrete backend for MTK
BenChung Feb 18, 2025
ae8dd96
feat: retain original equations of the system in `TearingState`
AayushSabharwal Feb 20, 2025
0778868
feat: allow namespacing statemachine equations
AayushSabharwal Mar 14, 2025
295a4ce
feat: propagate state machines in structural simplification
AayushSabharwal Mar 14, 2025
6c0b055
Handle nothing updates better
BenChung Mar 15, 2025
069dc74
Redefine the discrete_compile interface a bit
BenChung Mar 15, 2025
3fba60e
Change the external synchronous signature to include the id/clock map
BenChung May 14, 2025
dd1746e
feat: add `zero_crossing_id` to `SymbolicContinuousCallback`
AayushSabharwal Jun 20, 2025
af42247
feat: add `ZeroCrossing` and `EventClock` from zero crossing
AayushSabharwal Jun 20, 2025
8592fe2
feat: subset variables appropriately in clock inference
AayushSabharwal Jun 27, 2025
d023478
feat: add hook during problem construction
AayushSabharwal Jun 27, 2025
0f3cace
fix: handle `Union` types in `BufferTemplate`
AayushSabharwal Jul 8, 2025
8b7e1f7
feat: rewrite clock inference to support polyadic synchronous operators
AayushSabharwal Jul 9, 2025
bea6a41
Better support for multi-adic operators
BenChung Jul 11, 2025
00d7ae0
refactor: replace `is_synchronous_operator` with `is_timevarying_oper…
AayushSabharwal Jul 14, 2025
2906c01
fix: fix `is_time_domain_conversion` for new `input_timedomain`
AayushSabharwal Jul 14, 2025
5cbf8a7
fix: fix `input_timedomain` implementation for `Differential`
AayushSabharwal Jul 14, 2025
9bb1a42
build: bump SciMLBase compat
AayushSabharwal Aug 6, 2025
5efe374
feat: use and split initialization equations in clock inference
AayushSabharwal Aug 6, 2025
38a1d73
refactor: clock partition with no clock is assumed continuous
AayushSabharwal Aug 7, 2025
a2e318e
refactor: do not run clock inference for time-independent systems
AayushSabharwal Aug 7, 2025
600d723
fix: fix tests to account for new `zero_crossing_id`
AayushSabharwal Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.106.0"
SciMLBase = "2.108.0"
SciMLPublic = "1.0.0"
SciMLStructures = "1.7"
Serialization = "1"
Expand Down
9 changes: 7 additions & 2 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
@data InferredClock begin
Inferred
InferredDiscrete
InferredDiscrete(Int)
end

const InferredTimeDomain = InferredClock.Type
using .InferredClock: Inferred, InferredDiscrete

function InferredClock.InferredDiscrete()
return InferredDiscrete(0)
end

Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)

struct VariableTimeDomain end
Expand Down Expand Up @@ -50,7 +54,7 @@ has_time_domain(x::Num) = has_time_domain(value(x))
has_time_domain(x) = false

for op in [Differential]
@eval input_timedomain(::$op, arg = nothing) = ContinuousClock()
@eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),)
@eval output_timedomain(::$op, arg = nothing) = ContinuousClock()
end

Expand Down Expand Up @@ -97,6 +101,7 @@ function is_discrete_domain(x)
end

sampletime(c) = Moshi.Match.@match c begin
x::SciMLBase.AbstractClock => nothing
PeriodicClock(dt) => dt
_ => nothing
end
Expand Down
48 changes: 41 additions & 7 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl
is_transparent_operator(x) = is_transparent_operator(typeof(x))
is_transparent_operator(::Type) = false

"""
$(TYPEDSIGNATURES)

Trait to be implemented for operators which determines whether the operator is applied to
a time-varying quantity and results in a time-varying quantity. For example, `Initial` and
`Pre` are not time-varying since while they are applied to variables, the application
results in a non-discrete-time parameter. `Differential`, `Shift`, `Sample` and `Hold` are
all time-varying operators. All time-varying operators must implement `input_timedomain` and
`output_timedomain`.
"""
is_timevarying_operator(x) = is_timevarying_operator(typeof(x))
is_timevarying_operator(::Type{<:Symbolics.Operator}) = true
is_timevarying_operator(::Type) = false

"""
function SampleTime()

Expand Down Expand Up @@ -314,12 +328,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
input_timedomain(op::Operator)

Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on.
Should return a tuple containing the time domain type for each argument to the operator.
"""
function input_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete()
(InferredDiscrete(),)
end

"""
Expand All @@ -334,34 +349,53 @@ function output_timedomain(s::Shift, arg = nothing)
InferredDiscrete()
end

input_timedomain(::Sample, _ = nothing) = ContinuousClock()
input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),)
output_timedomain(s::Sample, _ = nothing) = s.clock

function input_timedomain(h::Hold, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete() # the Hold accepts any discrete
(InferredDiscrete(),) # the Hold accepts any discrete
end
output_timedomain(::Hold, _ = nothing) = ContinuousClock()

sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)

changes_domain(op) = isoperator(op, Union{Sample, Hold})

function output_timedomain(x)
if isoperator(x, Operator)
return output_timedomain(operation(x), arguments(x)[])
args = arguments(x)
return output_timedomain(operation(x), if length(args) == 1
args[]
else
args
end)
else
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
end
end

function input_timedomain(x)
if isoperator(x, Operator)
return input_timedomain(operation(x), arguments(x)[])
args = arguments(x)
return input_timedomain(operation(x), if length(args) == 1
args[]
else
args
end)
else
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
end
end

function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...)
return SymbolicContinuousCallback(
[expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing;
affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing,
kwargs..., zero_crossing_id = name)
end

function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback)
return SciMLBase.Clocks.EventClock(cb.zero_crossing_id)
end
4 changes: 2 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,12 @@ The `Initial` operator. Used by initialization to store constant constraints on
of a system. See the documentation section on initialization for more information.
"""
struct Initial <: Symbolics.Operator end
is_timevarying_operator(::Type{Initial}) = false
Initial(x) = Initial()(x)
SymbolicUtils.promote_symtype(::Type{Initial}, T) = T
SymbolicUtils.isbinop(::Initial) = false
Base.nameof(::Initial) = :Initial
Base.show(io::IO, x::Initial) = print(io, "Initial")
input_timedomain(::Initial, _ = nothing) = ContinuousClock()
output_timedomain(::Initial, _ = nothing) = ContinuousClock()

function (f::Initial)(x)
# wrap output if wrapped input
Expand Down Expand Up @@ -1246,6 +1245,7 @@ function namespace_expr(
O
end
end

_nonum(@nospecialize x) = x isa Num ? x.val : x

"""
Expand Down
32 changes: 20 additions & 12 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,11 @@ before the callback is triggered.
"""
struct Pre <: Symbolics.Operator end
Pre(x) = Pre()(x)
is_timevarying_operator(::Type{Pre}) = false
SymbolicUtils.promote_symtype(::Type{Pre}, T) = T
SymbolicUtils.isbinop(::Pre) = false
Base.nameof(::Pre) = :Pre
Base.show(io::IO, x::Pre) = print(io, "Pre")
input_timedomain(::Pre, _ = nothing) = ContinuousClock()
output_timedomain(::Pre, _ = nothing) = ContinuousClock()
unPre(x::Num) = unPre(unwrap(x))
unPre(x::Symbolics.Arr) = unPre(unwrap(x))
unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
Expand Down Expand Up @@ -252,6 +251,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
finalize::Union{Affect, SymbolicAffect, Nothing}
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
reinitializealg::SciMLBase.DAEInitializationAlgorithm
zero_crossing_id::Symbol
end

function SymbolicContinuousCallback(
Expand All @@ -262,6 +262,7 @@ function SymbolicContinuousCallback(
finalize = nothing,
rootfind = SciMLBase.LeftRootFind,
reinitializealg = nothing,
zero_crossing_id = gensym(),
kwargs...)
conditions = (conditions isa AbstractVector) ? conditions : [conditions]

Expand All @@ -278,7 +279,7 @@ function SymbolicContinuousCallback(
SymbolicAffect(affect_neg; kwargs...),
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
finalize; kwargs...),
rootfind, reinitializealg)
rootfind, reinitializealg, zero_crossing_id)
end # Default affect to nothing

function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
Expand All @@ -297,7 +298,7 @@ end
function complete(cb::SymbolicContinuousCallback; kwargs...)
SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...),
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg)
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg, cb.zero_crossing_id)
end

make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...)
Expand Down Expand Up @@ -512,7 +513,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo
affect_neg = namespace_affects(affect_negs(cb), s),
initialize = namespace_affects(initialize_affects(cb), s),
finalize = namespace_affects(finalize_affects(cb), s),
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg)
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg,
zero_crossing_id = cb.zero_crossing_id)
end

function namespace_conditions(condition, s)
Expand All @@ -536,6 +538,8 @@ function Base.hash(cb::AbstractCallback, s::UInt)
s = hash(finalize_affects(cb), s)
!is_discrete(cb) && (s = hash(cb.rootfind, s))
hash(cb.reinitializealg, s)
!is_discrete(cb) && (s = hash(cb.zero_crossing_id, s))
return s
end

###########################
Expand Down Expand Up @@ -570,13 +574,17 @@ function finalize_affects(cbs::Vector{<:AbstractCallback})
end

function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback)
(is_discrete(e1) === is_discrete(e2)) || return false
(isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) &&
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) &&
isequal(e1.reinitializealg, e2.reinitializealg) ||
return false
is_discrete(e1) ||
(isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind))
is_discrete(e1) === is_discrete(e2) || return false
isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false
isequal(e1.initialize, e2.initialize) || return false
isequal(e1.finalize, e2.finalize) || return false
isequal(e1.reinitializealg, e2.reinitializealg) || return false
if !is_discrete(e1)
isequal(e1.affect_neg, e2.affect_neg) || return false
isequal(e1.rootfind, e2.rootfind) || return false
isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false
end
return true
end

Base.isempty(cb::AbstractCallback) = isempty(cb.conditions)
Expand Down
Loading
Loading