Skip to content

Commit 5452300

Browse files
authored
Merge pull request #2138 from SciML/myb/clean_clock
Make clock inference more generic
2 parents b4ebc99 + 6a5e70b commit 5452300

File tree

5 files changed

+39
-28
lines changed

5 files changed

+39
-28
lines changed

Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4646
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
4747
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4848

49-
[weakdeps]
50-
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
51-
52-
[extensions]
53-
MTKDeepDiffsExt = "DeepDiffs"
54-
5549
[compat]
5650
AbstractTrees = "0.3, 0.4"
5751
ArrayInterface = "6, 7"
@@ -90,6 +84,9 @@ UnPack = "0.1, 1.0"
9084
Unitful = "1.1"
9185
julia = "1.6"
9286

87+
[extensions]
88+
MTKDeepDiffsExt = "DeepDiffs"
89+
9390
[extras]
9491
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
9592
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
@@ -116,3 +113,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
116113

117114
[targets]
118115
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
116+
117+
[weakdeps]
118+
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"

src/clock.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ struct Inferred <: TimeDomain end
77
struct InferredDiscrete <: AbstractDiscrete end
88
struct Continuous <: TimeDomain end
99

10-
const UnknownDomain = Union{Nothing, Inferred, InferredDiscrete}
11-
const InferredDomain = Union{Inferred, InferredDiscrete}
12-
1310
Symbolics.option_to_metadata_type(::Val{:timedomain}) = TimeDomain
1411

1512
"""
@@ -101,18 +98,27 @@ end
10198
abstract type AbstractClock <: AbstractDiscrete end
10299

103100
"""
104-
Clock <: AbstractClock
105-
Clock(t; dt)
101+
Clock <: AbstractClock
102+
Clock([t]; dt)
103+
106104
The default periodic clock with independent variables `t` and tick interval `dt`.
107105
If `dt` is left unspecified, it will be inferred (if possible).
108106
"""
109107
struct Clock <: AbstractClock
110108
"Independent variable"
111-
t::Any
109+
t::Union{Nothing, Symbolic}
112110
"Period"
113-
dt::Any
114-
Clock(t, dt = nothing) = new(value(t), dt)
111+
dt::Union{Nothing, Float64}
112+
Clock(t::Union{Num, Symbolic}, dt = nothing) = new(value(t), dt)
113+
Clock(t::Nothing, dt = nothing) = new(t, dt)
115114
end
115+
Clock(dt::Real) = Clock(nothing, dt)
116+
Clock() = Clock(nothing, nothing)
116117

117118
sampletime(c) = isdefined(c, :dt) ? c.dt : nothing
118-
Base.:(==)(c1::Clock, c2::Clock) = isequal(c1.t, c2.t) && c1.dt == c2.dt
119+
Base.hash(c::Clock, seed::UInt) = hash(c.dt, seed 0x953d7a9a18874b90)
120+
function Base.:(==)(c1::Clock, c2::Clock)
121+
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) && c1.dt == c2.dt
122+
end
123+
124+
is_concrete_time_domain(x) = x isa Union{AbstractClock, Continuous}

src/discretedomain.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ julia> Δ = Shift(t)
2323
"""
2424
struct Shift <: Operator
2525
"""Fixed Shift"""
26-
t::Any
26+
t::Union{Nothing, Symbolic}
2727
steps::Int
2828
Shift(t, steps = 1) = new(value(t), steps)
2929
end
30+
Shift(steps::Int) = new(nothing, steps)
3031
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
3132
function (D::Shift)(x, allow_zero = false)
3233
!allow_zero && D.steps == 0 && return x
@@ -38,7 +39,7 @@ function (D::Shift)(x::Num, allow_zero = false)
3839
if istree(vt)
3940
op = operation(vt)
4041
if op isa Shift
41-
if isequal(D.t, op.t)
42+
if D.t === nothing || isequal(D.t, op.t)
4243
arg = arguments(vt)[1]
4344
newsteps = D.steps + op.steps
4445
return Num(newsteps == 0 ? arg : Shift(D.t, newsteps)(arg))
@@ -75,7 +76,7 @@ Represents a sample operator. A discrete-time signal is created by sampling a co
7576
7677
# Constructors
7778
`Sample(clock::TimeDomain = InferredDiscrete())`
78-
`Sample(t, dt::Real)`
79+
`Sample([t], dt::Real)`
7980
8081
`Sample(x::Num)`, with a single argument, is shorthand for `Sample()(x)`.
8182

src/systems/clock_inference.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
1-
struct ClockInference
2-
ts::TearingState
1+
struct ClockInference{S}
2+
ts::S
33
eq_domain::Vector{TimeDomain}
44
var_domain::Vector{TimeDomain}
55
inferred::BitSet
66
end
77

8-
function ClockInference(ts::TearingState)
9-
@unpack fullvars, structure = ts
8+
function ClockInference(ts::TransformationState)
9+
@unpack structure = ts
1010
@unpack graph = structure
1111
eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)]
1212
var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)]
1313
inferred = BitSet()
14-
for (i, v) in enumerate(fullvars)
14+
for (i, v) in enumerate(get_fullvars(ts))
1515
d = get_time_domain(v)
16-
if d isa Union{AbstractClock, Continuous}
16+
if is_concrete_time_domain(d)
1717
push!(inferred, i)
18-
dd = d
19-
var_domain[i] = dd
18+
var_domain[i] = d
2019
end
2120
end
2221
ClockInference(ts, eq_domain, var_domain, inferred)
2322
end
2423

2524
function infer_clocks!(ci::ClockInference)
2625
@unpack ts, eq_domain, var_domain, inferred = ci
27-
@unpack fullvars = ts
2826
@unpack graph = ts.structure
27+
fullvars = get_fullvars(ts)
2928
isempty(inferred) && return ci
3029
# TODO: add a graph type to do this lazily
3130
var_graph = SimpleGraph(ndsts(graph))
@@ -78,7 +77,7 @@ end
7877

7978
function split_system(ci::ClockInference)
8079
@unpack ts, eq_domain, var_domain, inferred = ci
81-
@unpack fullvars = ts
80+
fullvars = get_fullvars(ts)
8281
@unpack graph, var_to_diff = ts.structure
8382
continuous_id = Ref(0)
8483
clock_to_id = Dict{TimeDomain, Int}()

src/systems/systemstructure.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export initialize_system_structure, find_linear_equations
3131
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete
3232
export dervars_range, diffvars_range, algvars_range
3333
export DiffGraph, complete!
34+
export get_fullvars
3435

3536
struct DiffGraph <: Graphs.AbstractGraph{Int}
3637
primal_to_diff::Vector{Union{Int, Nothing}}
@@ -138,6 +139,8 @@ end
138139
abstract type TransformationState{T} end
139140
abstract type AbstractTearingState{T} <: TransformationState{T} end
140141

142+
get_fullvars(ts::TransformationState) = ts.fullvars
143+
141144
Base.@kwdef mutable struct SystemStructure
142145
# Maps the (index of) a variable to the (index of) the variable describing
143146
# its derivative.
@@ -201,6 +204,8 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
201204
extra_eqs::Vector
202205
end
203206

207+
TransformationState(sys::AbstractSystem) = TearingState(sys)
208+
204209
function Base.show(io::IO, state::TearingState)
205210
print(io, "TearingState of ", typeof(state.sys))
206211
end

0 commit comments

Comments
 (0)