Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ DataInterpolations = "4, 5, 6, 7, 8"
DiffEqBase = "6"
DocStringExtensions = "0.7, 0.8, 0.9"
MLUtils = "0.3, 0.4"
ModelingToolkit = "10"
ModelingToolkit = "11"
OrdinaryDiffEqTsit5 = "1"
Parameters = "0.12"
ProgressMeter = "1.6"
Expand All @@ -41,8 +41,8 @@ SciMLStructures = "1"
Setfield = "1"
Statistics = "1"
StatsBase = "0.32.0, 0.33, 0.34"
SymbolicUtils = "2, 3, 4"
Symbolics = "5.30.1, 6"
SymbolicUtils = "4"
Symbolics = "7"
julia = "1.10"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion lib/DataDrivenSR/src/DataDrivenSR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DataDrivenSR

using DataDrivenDiffEq
# Load specific (abstract) types
using DataDrivenDiffEq: AbstractBasis
using DataDrivenDiffEq: AbstractBasis, Difference
using DataDrivenDiffEq: AbstractDataDrivenAlgorithm
using DataDrivenDiffEq: AbstractDataDrivenResult
using DataDrivenDiffEq: AbstractDataDrivenProblem
Expand Down
4 changes: 4 additions & 0 deletions src/DataDrivenDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ using Symbolics: scalarize, variable, value
@reexport using ModelingToolkit: unknowns, parameters, independent_variable, observed,
get_iv, get_observed

# Local Difference operator (removed from Symbolics v7)
include("./difference.jl")
export Difference

using Random
using QuadGK
using Statistics
Expand Down
12 changes: 8 additions & 4 deletions src/basis/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,11 @@ function get_parameter_values(x::Basis)
return Float64[]
end
map(ps) do p
val = if hasmetadata(p, Symbolics.VariableDefaultValue)
# In Symbolics v7, hasmetadata check for VariableDefaultValue may not work
# Use try-catch to handle getdefaultval which throws if no default exists
val = try
Symbolics.getdefaultval(p)
else
catch
zero(Symbolics.symtype(p))
end
# Unwrap symbolic values to numeric values for use in ODEProblem
Expand All @@ -562,9 +564,11 @@ Values are unwrapped from symbolic wrappers to ensure compatibility with ODEProb
"""
function get_parameter_map(x::Basis)
map(parameters(x)) do p
val = if hasmetadata(p, Symbolics.VariableDefaultValue)
# In Symbolics v7, hasmetadata check for VariableDefaultValue may not work
# Use try-catch to handle getdefaultval which throws if no default exists
val = try
Symbolics.getdefaultval(p)
else
catch
zero(Symbolics.symtype(p))
end
# Unwrap symbolic values to numeric values for use in ODEProblem
Expand Down
26 changes: 20 additions & 6 deletions src/basis/utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
## Create linear independent basis

# Helper to check if x is a constant (Number or SymbolicUtils Const type)
_is_constant(x::Number) = true
function _is_constant(x)
# In SymbolicUtils v4+, constants are wrapped in Const type
# Check using isconst if available
SymbolicUtils.isconst(x)
end

count_operation(x::Number, op::Function, nested::Bool = true) = 0
function count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true)
issym(x) && return 0
# Check if x is a symbol or not a call (e.g., Const in SymbolicUtils v4)
(issym(x) || !iscall(x)) && return 0
if operation(x) == op
if is_unary(op)
# Handles sin, cos and stuff
Expand Down Expand Up @@ -76,7 +86,8 @@ function remove_constant_factor(x)
# Create a new array
ops = Array{Any}(undef, n_ops)
@views split_term!(ops, x, [*])
filter!(x -> !isa(x, Number), ops)
# Filter out constants (both Number and SymbolicUtils Const types)
filter!(x -> !_is_constant(x), ops)
return Num(prod(ops))
end

Expand Down Expand Up @@ -106,15 +117,18 @@ function create_linear_independent_eqs(ops::AbstractVector, simplify_eqs::Bool =
return simplify_eqs ? simplify.(Num.(u_o)) : Num.(u_o)
end

function is_dependent(x::SymbolicUtils.Symbolic, y::SymbolicUtils.Symbolic)
occursin(y, x)
function is_dependent(x::SymbolicUtils.BasicSymbolic, y::SymbolicUtils.BasicSymbolic)
# In SymbolicUtils v4, occursin was removed. Use get_variables instead.
# Check if y appears in the variables of x
vars = Symbolics.get_variables(x)
y in vars
end

function is_dependent(x::Any, y::SymbolicUtils.Symbolic)
function is_dependent(x::Any, y::SymbolicUtils.BasicSymbolic)
false
end

function is_dependent(x::SymbolicUtils.Symbolic, y::Any)
function is_dependent(x::SymbolicUtils.BasicSymbolic, y::Any)
false
end

Expand Down
47 changes: 47 additions & 0 deletions src/difference.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Local Difference operator for discrete-time systems
# This was removed from Symbolics.jl v7, so we define it locally for backwards compatibility
# See: https://github.com/SciML/DataDrivenDiffEq.jl/issues/563

using Symbolics: Operator, value, unwrap, wrap
using SymbolicUtils: term

"""
Difference(t; dt, update=false)

Represents a difference operator for discrete-time systems.

# Fields

- `t`: The independent variable
- `dt`: The time step
- `update`: If true, represents a shift/update operator

# Examples

```julia
@variables t
d = Difference(t; dt = 0.01)
```
"""
Comment on lines +8 to +25
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AayushSabharwal is this the right approach here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Difference was MTK's old way of handling discrete systems. That doesn't work anymore.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas have you already released this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It doesn't work anymore for MTK but this package needs a representation of it in order to not break.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can it not use the new MTK discrete system interface?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't even use MTK at all, why would it start?

struct Difference <: Operator
t
dt
update::Bool
Difference(t; dt, update = false) = new(value(t), dt, update)
end

(D::Difference)(x) = term(D, unwrap(x))
(D::Difference)(x::Num) = wrap(D(unwrap(x)))

# More specific method to avoid ambiguity with SymbolicUtils.Operator method
SymbolicUtils.promote_symtype(::Difference, ::Type{T}) where {T} = T

function Base.show(io::IO, D::Difference)
print(io, "Difference(", D.t, "; dt=", D.dt, ", update=", D.update, ")")
end
Base.nameof(::Difference) = :Difference

function Base.:(==)(D1::Difference, D2::Difference)
isequal(D1.t, D2.t) && isequal(D1.dt, D2.dt) && isequal(D1.update, D2.update)
end
Base.hash(D::Difference, u::UInt) = hash(D.dt, hash(D.t, xor(u, 0x055640d6d952f101)))
7 changes: 5 additions & 2 deletions src/utils/build_basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ end
function __assert_linearity(eqs::AbstractVector{Num}, x::AbstractVector)
j = Symbolics.jacobian(eqs, x)
# Check if any of the variables is in the jacobian
v = get_variables.(j)
# get_variables returns a Set in Symbolics v7, so we need to collect and flatten
v_sets = get_variables.(j)
isempty(v_sets) && return true
# Flatten all Sets into a single collection and get unique variables
v = unique(reduce(union, v_sets; init = Set()))
isempty(v) && return true
v = unique(v)
for xi in x, vi in v

isequal(xi, vi) && return false
Expand Down
4 changes: 3 additions & 1 deletion test/basis/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ end

@test size(basis) == (6,)
@test size(basis_2) == (5,)
@test basis_2([1.0; 2.0; π], [0.0; 1.0]) ≈ [1.0; -1.0; 2.0; π; 1.0]
# Note: Order may differ due to internal Symbolics representation
# The linear_independent basis extracts terms which may be reordered
@test basis_2([1.0; 2.0; π], [0.0; 1.0]) ≈ [1.0; -1.0; π; 2.0; 1.0]
@test basis([1.0; 2.0; π], [0.0; 1.0]) ≈ [1.0; 2.0; π; -1.0; 5 * π + 2.0; 1.0]

@test size(basis) == size(basis_2) .+ (1,)
Expand Down
3 changes: 2 additions & 1 deletion test/basis/implicit_basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ end

for r in [direct_res, discrete_res, cont_res]
lhs = Num.(map(eq -> eq.rhs, equations(direct_res)))
xs = unique(reduce(vcat, Symbolics.get_variables.(lhs)))
# get_variables returns Sets in Symbolics v7, so use union to combine them
xs = collect(reduce(union, Symbolics.get_variables.(lhs); init = Set()))
@test !any(DataDrivenDiffEq.is_dependent(Num.(xs), du))
@test any(DataDrivenDiffEq.is_dependent(Num.(xs), u))
end
Expand Down
31 changes: 14 additions & 17 deletions test/problem/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,25 +167,22 @@ end
using OrdinaryDiffEqTsit5
using ModelingToolkit: t_nounits as time, D_nounits as D

@mtkmodel Autoregulation begin
@parameters begin
α = 1.0
β = 1.3
γ = 2.0
δ = 0.5
end
@variables begin
(x(time))[1:2] = [20.0; 12.0]
end
@equations begin
D(x[1]) ~ α / (1 + x[2]) - β * x[1]
D(x[2]) ~ γ / (1 + x[1]) - δ * x[2]
end
end
# Define autoregulation system without @mtkmodel macro
# (avoids macro import issues with SafeTestsets)
@parameters α=1.0 β=1.3 γ=2.0 δ=0.5
@variables (x(time))[1:2]=[20.0, 12.0]
x = collect(x)

eqs = [
D(x[1]) ~ α / (1 + x[2]) - β * x[1],
D(x[2]) ~ γ / (1 + x[1]) - δ * x[2]
]

@named sys = System(eqs, time)
sys_compiled = mtkcompile(sys)

@mtkcompile sys = Autoregulation()
tspan = (0.0, 5.0)
de_problem = ODEProblem{true, SciMLBase.NoSpecialize}(sys, [], tspan)
de_problem = ODEProblem{true, SciMLBase.NoSpecialize}(sys_compiled, [], tspan)
de_solution = solve(de_problem, Tsit5(), saveat = 0.005)
prob = DataDrivenProblem(de_solution)
@test is_valid(prob)
Expand Down
Loading