Skip to content

Convert Moshi.jl to optional extension #1094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand All @@ -35,7 +36,6 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
[weakdeps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand All @@ -47,6 +47,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
SciMLBaseMLStyleExt = "MLStyle"
SciMLBaseMakieExt = "Makie"
SciMLBaseMoshiExt = "Moshi"
SciMLBasePartialFunctionsExt = "PartialFunctions"
SciMLBasePyCallExt = "PyCall"
SciMLBasePythonCallExt = "PythonCall"
Expand Down
4 changes: 2 additions & 2 deletions ext/SciMLBaseMLStyleExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module SciMLBaseMLStyleExt
using SciMLBase: TimeDomain, ContinuousClock, SolverStepClock, PeriodicClock
using MLStyle: MLStyle
using MLStyle.AbstractPatterns: literal, wildcard, PComp, BasicPatterns, decons
using Moshi.Data: isa_variant
# Removed Moshi dependency - using basic struct detection instead

# This makes Singletons also work without parentheses in matches
MLStyle.is_enum(::Type{ContinuousClock}) = true
Expand All @@ -26,7 +26,7 @@ function MLStyle.pattern_uncall(::Type{SolverStepClock}, self::Function, _, _, _
end

function periodic_clock_pattern(c)
if c isa TimeDomain && isa_variant(c, PeriodicClock)
if c isa PeriodicClock
(c.dt, c.phase)
else
# These values are used in match results, but they shouldn't.
Expand Down
69 changes: 69 additions & 0 deletions ext/SciMLBaseMoshiExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
module SciMLBaseMoshiExt

using SciMLBase
using Moshi.Data: @data
using Moshi.Match: @match

# When Moshi is available, override the basic implementations with @match-based versions
# This provides the enhanced pattern matching functionality

function SciMLBase.isclock(c::SciMLBase.TimeDomain)
@match c begin
SciMLBase.PeriodicClock() => true
_ => false
end
end

function SciMLBase.issolverstepclock(c::SciMLBase.TimeDomain)
@match c begin
SciMLBase.SolverStepClock() => true
_ => false
end
end

function SciMLBase.iscontinuous(c::SciMLBase.TimeDomain)
@match c begin
SciMLBase.ContinuousClock() => true
_ => false
end
end

function SciMLBase.first_clock_tick_time(c::SciMLBase.TimeDomain, t0)
@match c begin
SciMLBase.PeriodicClock(dt) => ceil(t0 / dt) * dt
SciMLBase.SolverStepClock() => t0
SciMLBase.ContinuousClock() => error("ContinuousClock() is not a discrete clock")
end
end

function SciMLBase.canonicalize_indexed_clock(ic::SciMLBase.IndexedClock, sol::SciMLBase.AbstractTimeseriesSolution)
c = ic.clock

return @match c begin
SciMLBase.PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
SciMLBase.SolverStepClock() => begin
ssc_idx = findfirst(eachindex(sol.discretes)) do i
!isa(sol.discretes[i].t, AbstractRange)
end
sol.discretes[ssc_idx].t[ic.idx]
end
SciMLBase.ContinuousClock() => sol.t[ic.idx]
end
end

# Also define Moshi-based types for users who want the original @data experience
@data MoshiClocks begin
MoshiContinuousClock
struct MoshiPeriodicClock
dt::Union{Nothing, Float64, Rational{Int}}
phase::Float64 = 0.0
end
MoshiSolverStepClock
end

# Convenience constructors for the Moshi types
MoshiClock(dt::Union{<:Rational, Float64}; phase = 0.0) = MoshiPeriodicClock(dt, phase)
MoshiClock(dt; phase = 0.0) = MoshiPeriodicClock(convert(Float64, dt), phase)
MoshiClock(; phase = 0.0) = MoshiPeriodicClock(nothing, phase)

end
4 changes: 1 addition & 3 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import RuntimeGeneratedFunctions
import EnumX
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset, @delete, @insert
using Moshi.Data: @data
using Moshi.Match: @match
import StaticArraysCore
import Adapt: adapt_structure, adapt

Expand Down Expand Up @@ -862,7 +860,7 @@ export step!, deleteat!, addat!, get_tmp_cache,

export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback

export Clocks, TimeDomain, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous
export Clocks, TimeDomain, Clock, Continuous, ContinuousClock, PeriodicClock, SolverStepClock, IndexedClock, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous, first_clock_tick_time, canonicalize_indexed_clock

export ODEAliasSpecifier, LinearAliasSpecifier

Expand Down
113 changes: 55 additions & 58 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,76 +1,70 @@
@data Clocks begin
ContinuousClock
struct PeriodicClock
# Clock functionality requires Moshi extension - original implementation was:
#
# @data Clocks begin
# ContinuousClock
# struct PeriodicClock
# dt::Union{Nothing, Float64, Rational{Int}}
# phase::Float64 = 0.0
# end
# SolverStepClock
# end
#
# This is recreated by the SciMLBaseMoshiExt when Moshi is loaded

# Create a module with the same structure as the original, but with stub implementations
module Clocks
abstract type Type end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AayushSabharwal yeah this isn't valid right because it needs the macro?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah


struct ContinuousClock <: Type end
struct SolverStepClock <: Type end

struct PeriodicClock <: Type
dt::Union{Nothing, Float64, Rational{Int}}
phase::Float64 = 0.0
phase::Float64
PeriodicClock(dt, phase = 0.0) = new(dt, phase)
end
SolverStepClock

# Keyword constructor
PeriodicClock(; dt = nothing, phase = 0.0) = PeriodicClock(dt, phase)
end

# for backwards compatibility
# Re-export for backwards compatibility
const TimeDomain = Clocks.Type
using .Clocks: ContinuousClock, PeriodicClock, SolverStepClock
const Continuous = ContinuousClock()
(clock::TimeDomain)() = clock

Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d)

"""
Clock(dt)
Clock()

The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will
be inferred (if possible).
"""
# These will be overridden by the extension with @match-based versions
Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase)
Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase)
Clock(; phase = 0.0) = PeriodicClock(nothing, phase)

@doc """
SolverStepClock

A clock that ticks at each solver step (sometimes referred to as "continuous sample time").
This clock **does generally not have equidistant tick intervals**, instead, the tick
interval depends on the adaptive step-size selection of the continuous solver, as well as
any continuous event handling. If adaptivity of the solver is turned off and there are no
continuous events, the tick interval will be given by the fixed solver time step `dt`.

Due to possibly non-equidistant tick intervals, this clock should typically not be used with
discrete-time systems that assume a fixed sample time, such as PID controllers and digital
filters.
""" SolverStepClock

isclock(c::TimeDomain) = @match c begin
PeriodicClock() => true
_ => false
end

issolverstepclock(c::TimeDomain) = @match c begin
SolverStepClock() => true
_ => false
end

iscontinuous(c::TimeDomain) = @match c begin
ContinuousClock() => true
_ => false
end

isclock(c::TimeDomain) = c isa PeriodicClock
issolverstepclock(c::TimeDomain) = c isa SolverStepClock
iscontinuous(c::TimeDomain) = c isa ContinuousClock
is_discrete_time_domain(c::TimeDomain) = !iscontinuous(c)

# workaround for https://github.com/Roger-luo/Moshi.jl/issues/43
# Fallbacks for non-TimeDomain types
isclock(::Any) = false
issolverstepclock(::Any) = false
iscontinuous(::Any) = false
is_discrete_time_domain(::Any) = false

function first_clock_tick_time(c, t0)
@match c begin
PeriodicClock(dt) => ceil(t0 / dt) * dt
SolverStepClock() => t0
ContinuousClock() => error("ContinuousClock() is not a discrete clock")
if c isa PeriodicClock
dt = c.dt
return ceil(t0 / dt) * dt
elseif c isa SolverStepClock
return t0
elseif c isa ContinuousClock
error("ContinuousClock() is not a discrete clock")
else
error("Unknown clock type: $(typeof(c))")
end
end

Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d)
(clock::TimeDomain)() = clock

struct IndexedClock{I}
clock::TimeDomain
idx::I
Expand All @@ -81,14 +75,17 @@ Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx)
function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution)
c = ic.clock

return @match c begin
PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
SolverStepClock() => begin
ssc_idx = findfirst(eachindex(sol.discretes)) do i
!isa(sol.discretes[i].t, AbstractRange)
end
sol.discretes[ssc_idx].t[ic.idx]
if c isa PeriodicClock
dt = c.dt
return ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
elseif c isa SolverStepClock
ssc_idx = findfirst(eachindex(sol.discretes)) do i
!isa(sol.discretes[i].t, AbstractRange)
end
ContinuousClock() => sol.t[ic.idx]
return sol.discretes[ssc_idx].t[ic.idx]
elseif c isa ContinuousClock
return sol.t[ic.idx]
else
error("Unknown clock type: $(typeof(c))")
end
end
Loading