Skip to content

Commit 6f2bc1b

Browse files
Merge pull request #3399 from SciML/bc/external-synchronous
Add support for an external synchronous compiler to discrete and hybrid systems
2 parents 29dd842 + 600d723 commit 6f2bc1b

16 files changed

+542
-109
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ RecursiveArrayTools = "3.26"
148148
Reexport = "0.2, 1"
149149
RuntimeGeneratedFunctions = "0.5.9"
150150
SCCNonlinearSolve = "1.0.0"
151-
SciMLBase = "2.106.0"
151+
SciMLBase = "2.108.0"
152152
SciMLPublic = "1.0.0"
153153
SciMLStructures = "1.7"
154154
Serialization = "1"

src/clock.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
@data InferredClock begin
22
Inferred
3-
InferredDiscrete
3+
InferredDiscrete(Int)
44
end
55

66
const InferredTimeDomain = InferredClock.Type
77
using .InferredClock: Inferred, InferredDiscrete
88

9+
function InferredClock.InferredDiscrete()
10+
return InferredDiscrete(0)
11+
end
12+
913
Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)
1014

1115
struct VariableTimeDomain end
@@ -50,7 +54,7 @@ has_time_domain(x::Num) = has_time_domain(value(x))
5054
has_time_domain(x) = false
5155

5256
for op in [Differential]
53-
@eval input_timedomain(::$op, arg = nothing) = ContinuousClock()
57+
@eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),)
5458
@eval output_timedomain(::$op, arg = nothing) = ContinuousClock()
5559
end
5660

@@ -97,6 +101,7 @@ function is_discrete_domain(x)
97101
end
98102

99103
sampletime(c) = Moshi.Match.@match c begin
104+
x::SciMLBase.AbstractClock => nothing
100105
PeriodicClock(dt) => dt
101106
_ => nothing
102107
end

src/discretedomain.jl

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl
1010
is_transparent_operator(x) = is_transparent_operator(typeof(x))
1111
is_transparent_operator(::Type) = false
1212

13+
"""
14+
$(TYPEDSIGNATURES)
15+
16+
Trait to be implemented for operators which determines whether the operator is applied to
17+
a time-varying quantity and results in a time-varying quantity. For example, `Initial` and
18+
`Pre` are not time-varying since while they are applied to variables, the application
19+
results in a non-discrete-time parameter. `Differential`, `Shift`, `Sample` and `Hold` are
20+
all time-varying operators. All time-varying operators must implement `input_timedomain` and
21+
`output_timedomain`.
22+
"""
23+
is_timevarying_operator(x) = is_timevarying_operator(typeof(x))
24+
is_timevarying_operator(::Type{<:Symbolics.Operator}) = true
25+
is_timevarying_operator(::Type) = false
26+
1327
"""
1428
function SampleTime()
1529
@@ -314,12 +328,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
314328
input_timedomain(op::Operator)
315329
316330
Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on.
331+
Should return a tuple containing the time domain type for each argument to the operator.
317332
"""
318333
function input_timedomain(s::Shift, arg = nothing)
319334
if has_time_domain(arg)
320335
return get_time_domain(arg)
321336
end
322-
InferredDiscrete()
337+
(InferredDiscrete(),)
323338
end
324339

325340
"""
@@ -334,34 +349,53 @@ function output_timedomain(s::Shift, arg = nothing)
334349
InferredDiscrete()
335350
end
336351

337-
input_timedomain(::Sample, _ = nothing) = ContinuousClock()
352+
input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),)
338353
output_timedomain(s::Sample, _ = nothing) = s.clock
339354

340355
function input_timedomain(h::Hold, arg = nothing)
341356
if has_time_domain(arg)
342357
return get_time_domain(arg)
343358
end
344-
InferredDiscrete() # the Hold accepts any discrete
359+
(InferredDiscrete(),) # the Hold accepts any discrete
345360
end
346361
output_timedomain(::Hold, _ = nothing) = ContinuousClock()
347362

348363
sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
349364
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)
350365

351-
changes_domain(op) = isoperator(op, Union{Sample, Hold})
352-
353366
function output_timedomain(x)
354367
if isoperator(x, Operator)
355-
return output_timedomain(operation(x), arguments(x)[])
368+
args = arguments(x)
369+
return output_timedomain(operation(x), if length(args) == 1
370+
args[]
371+
else
372+
args
373+
end)
356374
else
357375
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
358376
end
359377
end
360378

361379
function input_timedomain(x)
362380
if isoperator(x, Operator)
363-
return input_timedomain(operation(x), arguments(x)[])
381+
args = arguments(x)
382+
return input_timedomain(operation(x), if length(args) == 1
383+
args[]
384+
else
385+
args
386+
end)
364387
else
365388
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
366389
end
367390
end
391+
392+
function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...)
393+
return SymbolicContinuousCallback(
394+
[expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing;
395+
affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing,
396+
kwargs..., zero_crossing_id = name)
397+
end
398+
399+
function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback)
400+
return SciMLBase.Clocks.EventClock(cb.zero_crossing_id)
401+
end

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,12 @@ The `Initial` operator. Used by initialization to store constant constraints on
486486
of a system. See the documentation section on initialization for more information.
487487
"""
488488
struct Initial <: Symbolics.Operator end
489+
is_timevarying_operator(::Type{Initial}) = false
489490
Initial(x) = Initial()(x)
490491
SymbolicUtils.promote_symtype(::Type{Initial}, T) = T
491492
SymbolicUtils.isbinop(::Initial) = false
492493
Base.nameof(::Initial) = :Initial
493494
Base.show(io::IO, x::Initial) = print(io, "Initial")
494-
input_timedomain(::Initial, _ = nothing) = ContinuousClock()
495-
output_timedomain(::Initial, _ = nothing) = ContinuousClock()
496495

497496
function (f::Initial)(x)
498497
# wrap output if wrapped input
@@ -1246,6 +1245,7 @@ function namespace_expr(
12461245
O
12471246
end
12481247
end
1248+
12491249
_nonum(@nospecialize x) = x isa Num ? x.val : x
12501250

12511251
"""

src/systems/callbacks.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,11 @@ before the callback is triggered.
143143
"""
144144
struct Pre <: Symbolics.Operator end
145145
Pre(x) = Pre()(x)
146+
is_timevarying_operator(::Type{Pre}) = false
146147
SymbolicUtils.promote_symtype(::Type{Pre}, T) = T
147148
SymbolicUtils.isbinop(::Pre) = false
148149
Base.nameof(::Pre) = :Pre
149150
Base.show(io::IO, x::Pre) = print(io, "Pre")
150-
input_timedomain(::Pre, _ = nothing) = ContinuousClock()
151-
output_timedomain(::Pre, _ = nothing) = ContinuousClock()
152151
unPre(x::Num) = unPre(unwrap(x))
153152
unPre(x::Symbolics.Arr) = unPre(unwrap(x))
154153
unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
@@ -252,6 +251,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
252251
finalize::Union{Affect, SymbolicAffect, Nothing}
253252
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
254253
reinitializealg::SciMLBase.DAEInitializationAlgorithm
254+
zero_crossing_id::Symbol
255255
end
256256

257257
function SymbolicContinuousCallback(
@@ -262,6 +262,7 @@ function SymbolicContinuousCallback(
262262
finalize = nothing,
263263
rootfind = SciMLBase.LeftRootFind,
264264
reinitializealg = nothing,
265+
zero_crossing_id = gensym(),
265266
kwargs...)
266267
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
267268

@@ -278,7 +279,7 @@ function SymbolicContinuousCallback(
278279
SymbolicAffect(affect_neg; kwargs...),
279280
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
280281
finalize; kwargs...),
281-
rootfind, reinitializealg)
282+
rootfind, reinitializealg, zero_crossing_id)
282283
end # Default affect to nothing
283284

284285
function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
@@ -297,7 +298,7 @@ end
297298
function complete(cb::SymbolicContinuousCallback; kwargs...)
298299
SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...),
299300
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
300-
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg)
301+
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg, cb.zero_crossing_id)
301302
end
302303

303304
make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...)
@@ -512,7 +513,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo
512513
affect_neg = namespace_affects(affect_negs(cb), s),
513514
initialize = namespace_affects(initialize_affects(cb), s),
514515
finalize = namespace_affects(finalize_affects(cb), s),
515-
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg)
516+
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg,
517+
zero_crossing_id = cb.zero_crossing_id)
516518
end
517519

518520
function namespace_conditions(condition, s)
@@ -536,6 +538,8 @@ function Base.hash(cb::AbstractCallback, s::UInt)
536538
s = hash(finalize_affects(cb), s)
537539
!is_discrete(cb) && (s = hash(cb.rootfind, s))
538540
hash(cb.reinitializealg, s)
541+
!is_discrete(cb) && (s = hash(cb.zero_crossing_id, s))
542+
return s
539543
end
540544

541545
###########################
@@ -570,13 +574,17 @@ function finalize_affects(cbs::Vector{<:AbstractCallback})
570574
end
571575

572576
function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback)
573-
(is_discrete(e1) === is_discrete(e2)) || return false
574-
(isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) &&
575-
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) &&
576-
isequal(e1.reinitializealg, e2.reinitializealg) ||
577-
return false
578-
is_discrete(e1) ||
579-
(isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind))
577+
is_discrete(e1) === is_discrete(e2) || return false
578+
isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false
579+
isequal(e1.initialize, e2.initialize) || return false
580+
isequal(e1.finalize, e2.finalize) || return false
581+
isequal(e1.reinitializealg, e2.reinitializealg) || return false
582+
if !is_discrete(e1)
583+
isequal(e1.affect_neg, e2.affect_neg) || return false
584+
isequal(e1.rootfind, e2.rootfind) || return false
585+
isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false
586+
end
587+
return true
580588
end
581589

582590
Base.isempty(cb::AbstractCallback) = isempty(cb.conditions)

0 commit comments

Comments
 (0)