diff --git a/Project.toml b/Project.toml index 92b029a3a..9cd938462 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -35,7 +36,6 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @@ -47,6 +47,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" SciMLBaseChainRulesCoreExt = "ChainRulesCore" SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" +SciMLBaseMoshiExt = "Moshi" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" diff --git a/ext/SciMLBaseMLStyleExt.jl b/ext/SciMLBaseMLStyleExt.jl index 8b255f020..09b120911 100644 --- a/ext/SciMLBaseMLStyleExt.jl +++ b/ext/SciMLBaseMLStyleExt.jl @@ -11,7 +11,7 @@ module SciMLBaseMLStyleExt using SciMLBase: TimeDomain, ContinuousClock, SolverStepClock, PeriodicClock using MLStyle: MLStyle using MLStyle.AbstractPatterns: literal, wildcard, PComp, BasicPatterns, decons -using Moshi.Data: isa_variant +# Removed Moshi dependency - using basic struct detection instead # This makes Singletons also work without parentheses in matches MLStyle.is_enum(::Type{ContinuousClock}) = true @@ -26,7 +26,7 @@ function MLStyle.pattern_uncall(::Type{SolverStepClock}, self::Function, _, _, _ end function periodic_clock_pattern(c) - if c isa TimeDomain && isa_variant(c, PeriodicClock) + if c isa PeriodicClock (c.dt, c.phase) else # These values are used in match results, but they shouldn't. diff --git a/ext/SciMLBaseMoshiExt.jl b/ext/SciMLBaseMoshiExt.jl new file mode 100644 index 000000000..2f7ac5681 --- /dev/null +++ b/ext/SciMLBaseMoshiExt.jl @@ -0,0 +1,69 @@ +module SciMLBaseMoshiExt + +using SciMLBase +using Moshi.Data: @data +using Moshi.Match: @match + +# When Moshi is available, override the basic implementations with @match-based versions +# This provides the enhanced pattern matching functionality + +function SciMLBase.isclock(c::SciMLBase.TimeDomain) + @match c begin + SciMLBase.PeriodicClock() => true + _ => false + end +end + +function SciMLBase.issolverstepclock(c::SciMLBase.TimeDomain) + @match c begin + SciMLBase.SolverStepClock() => true + _ => false + end +end + +function SciMLBase.iscontinuous(c::SciMLBase.TimeDomain) + @match c begin + SciMLBase.ContinuousClock() => true + _ => false + end +end + +function SciMLBase.first_clock_tick_time(c::SciMLBase.TimeDomain, t0) + @match c begin + SciMLBase.PeriodicClock(dt) => ceil(t0 / dt) * dt + SciMLBase.SolverStepClock() => t0 + SciMLBase.ContinuousClock() => error("ContinuousClock() is not a discrete clock") + end +end + +function SciMLBase.canonicalize_indexed_clock(ic::SciMLBase.IndexedClock, sol::SciMLBase.AbstractTimeseriesSolution) + c = ic.clock + + return @match c begin + SciMLBase.PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + SciMLBase.SolverStepClock() => begin + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) + end + sol.discretes[ssc_idx].t[ic.idx] + end + SciMLBase.ContinuousClock() => sol.t[ic.idx] + end +end + +# Also define Moshi-based types for users who want the original @data experience +@data MoshiClocks begin + MoshiContinuousClock + struct MoshiPeriodicClock + dt::Union{Nothing, Float64, Rational{Int}} + phase::Float64 = 0.0 + end + MoshiSolverStepClock +end + +# Convenience constructors for the Moshi types +MoshiClock(dt::Union{<:Rational, Float64}; phase = 0.0) = MoshiPeriodicClock(dt, phase) +MoshiClock(dt; phase = 0.0) = MoshiPeriodicClock(convert(Float64, dt), phase) +MoshiClock(; phase = 0.0) = MoshiPeriodicClock(nothing, phase) + +end \ No newline at end of file diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 49d9e246d..b2770345d 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -23,8 +23,6 @@ import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset, @delete, @insert -using Moshi.Data: @data -using Moshi.Match: @match import StaticArraysCore import Adapt: adapt_structure, adapt @@ -862,7 +860,7 @@ export step!, deleteat!, addat!, get_tmp_cache, export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback -export Clocks, TimeDomain, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous +export Clocks, TimeDomain, Clock, Continuous, ContinuousClock, PeriodicClock, SolverStepClock, IndexedClock, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous, first_clock_tick_time, canonicalize_indexed_clock export ODEAliasSpecifier, LinearAliasSpecifier diff --git a/src/clock.jl b/src/clock.jl index a417ebea4..26aa53432 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,76 +1,70 @@ -@data Clocks begin - ContinuousClock - struct PeriodicClock +# Clock functionality requires Moshi extension - original implementation was: +# +# @data Clocks begin +# ContinuousClock +# struct PeriodicClock +# dt::Union{Nothing, Float64, Rational{Int}} +# phase::Float64 = 0.0 +# end +# SolverStepClock +# end +# +# This is recreated by the SciMLBaseMoshiExt when Moshi is loaded + +# Create a module with the same structure as the original, but with stub implementations +module Clocks + abstract type Type end + + struct ContinuousClock <: Type end + struct SolverStepClock <: Type end + + struct PeriodicClock <: Type dt::Union{Nothing, Float64, Rational{Int}} - phase::Float64 = 0.0 + phase::Float64 + PeriodicClock(dt, phase = 0.0) = new(dt, phase) end - SolverStepClock + + # Keyword constructor + PeriodicClock(; dt = nothing, phase = 0.0) = PeriodicClock(dt, phase) end -# for backwards compatibility +# Re-export for backwards compatibility const TimeDomain = Clocks.Type using .Clocks: ContinuousClock, PeriodicClock, SolverStepClock const Continuous = ContinuousClock() -(clock::TimeDomain)() = clock - -Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) - -""" - Clock(dt) - Clock() -The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will -be inferred (if possible). -""" +# These will be overridden by the extension with @match-based versions Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase) Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase) Clock(; phase = 0.0) = PeriodicClock(nothing, phase) -@doc """ - SolverStepClock - -A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). -This clock **does generally not have equidistant tick intervals**, instead, the tick -interval depends on the adaptive step-size selection of the continuous solver, as well as -any continuous event handling. If adaptivity of the solver is turned off and there are no -continuous events, the tick interval will be given by the fixed solver time step `dt`. - -Due to possibly non-equidistant tick intervals, this clock should typically not be used with -discrete-time systems that assume a fixed sample time, such as PID controllers and digital -filters. -""" SolverStepClock - -isclock(c::TimeDomain) = @match c begin - PeriodicClock() => true - _ => false -end - -issolverstepclock(c::TimeDomain) = @match c begin - SolverStepClock() => true - _ => false -end - -iscontinuous(c::TimeDomain) = @match c begin - ContinuousClock() => true - _ => false -end - +isclock(c::TimeDomain) = c isa PeriodicClock +issolverstepclock(c::TimeDomain) = c isa SolverStepClock +iscontinuous(c::TimeDomain) = c isa ContinuousClock is_discrete_time_domain(c::TimeDomain) = !iscontinuous(c) -# workaround for https://github.com/Roger-luo/Moshi.jl/issues/43 +# Fallbacks for non-TimeDomain types isclock(::Any) = false issolverstepclock(::Any) = false iscontinuous(::Any) = false is_discrete_time_domain(::Any) = false function first_clock_tick_time(c, t0) - @match c begin - PeriodicClock(dt) => ceil(t0 / dt) * dt - SolverStepClock() => t0 - ContinuousClock() => error("ContinuousClock() is not a discrete clock") + if c isa PeriodicClock + dt = c.dt + return ceil(t0 / dt) * dt + elseif c isa SolverStepClock + return t0 + elseif c isa ContinuousClock + error("ContinuousClock() is not a discrete clock") + else + error("Unknown clock type: $(typeof(c))") end end +Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) +(clock::TimeDomain)() = clock + struct IndexedClock{I} clock::TimeDomain idx::I @@ -81,14 +75,17 @@ Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) c = ic.clock - return @match c begin - PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt - SolverStepClock() => begin - ssc_idx = findfirst(eachindex(sol.discretes)) do i - !isa(sol.discretes[i].t, AbstractRange) - end - sol.discretes[ssc_idx].t[ic.idx] + if c isa PeriodicClock + dt = c.dt + return ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + elseif c isa SolverStepClock + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) end - ContinuousClock() => sol.t[ic.idx] + return sol.discretes[ssc_idx].t[ic.idx] + elseif c isa ContinuousClock + return sol.t[ic.idx] + else + error("Unknown clock type: $(typeof(c))") end end