Skip to content

Commit 0b8f906

Browse files
Fix Basis compatibility with ODEProblem (#552)
This commit fixes the issue where using a Basis object obtained from DMD or SINDy as the ODE function in ODEProblem would throw an error about the 'observed' field not existing. Changes: - Added ModelingToolkit interface methods for Basis in new file src/basis/modelingtoolkit_interface.jl - Implemented required AbstractSystem interface methods: equations, unknowns, parameters, get_observed, get_iv, nameof - Override getproperty to properly handle :observed field access - Override show method to fix display error related to namespacing flag - Added conditional definitions for namespacing-related methods to maintain compatibility across ModelingToolkit versions Fixes #552 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent c5e4697 commit 0b8f906

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

src/DataDrivenDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ struct ZeroDataDrivenAlgorithm <: AbstractDataDrivenAlgorithm end
9696
include("./basis/build_function.jl")
9797
include("./basis/utils.jl")
9898
include("./basis/type.jl")
99+
include("./basis/modelingtoolkit_interface.jl")
99100
export Basis
100101
export jacobian, dynamics
101102
export implicit_variables, states, controls
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# ModelingToolkit interface methods for Basis
2+
3+
import ModelingToolkit: equations, unknowns, parameters, get_observed, get_iv, nameof
4+
5+
# Equations
6+
function ModelingToolkit.equations(b::AbstractBasis)
7+
return getfield(b, :eqs)
8+
end
9+
10+
# Unknowns (states)
11+
function ModelingToolkit.unknowns(b::AbstractBasis)
12+
return states(b)
13+
end
14+
15+
# Parameters
16+
function ModelingToolkit.parameters(b::AbstractBasis)
17+
return getfield(b, :ps)
18+
end
19+
20+
# Observed
21+
function ModelingToolkit.get_observed(b::AbstractBasis)
22+
return getfield(b, :observed)
23+
end
24+
25+
# Independent variable
26+
function ModelingToolkit.get_iv(b::AbstractBasis)
27+
return getfield(b, :iv)
28+
end
29+
30+
# Name
31+
function ModelingToolkit.nameof(b::AbstractBasis)
32+
return getfield(b, :name)
33+
end
34+
35+
# Define has_namespacing to indicate that Basis doesn't support namespacing
36+
if isdefined(ModelingToolkit, :has_namespacing)
37+
function ModelingToolkit.has_namespacing(::AbstractBasis)
38+
return false
39+
end
40+
end
41+
42+
# Define namespace_expression if it exists
43+
if isdefined(ModelingToolkit, :namespace_expression)
44+
function ModelingToolkit.namespace_expression(b::AbstractBasis)
45+
return false # Basis doesn't support namespacing
46+
end
47+
end
48+
49+
# Override getproperty to handle :observed specially
50+
function Base.getproperty(b::AbstractBasis, name::Symbol)
51+
if name === :observed
52+
return get_observed(b)
53+
else
54+
# Fall back to getfield for direct field access
55+
return getfield(b, name)
56+
end
57+
end
58+
59+
# Override show to avoid ModelingToolkit's display method that needs namespacing
60+
function Base.show(io::IO, ::MIME"text/plain", b::AbstractBasis)
61+
Base.print(io, b)
62+
end

0 commit comments

Comments
 (0)