Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ MATLABDiffEq.jl is simply a solver on the DiffEq common interface, so for detail
However, the only options implemented are those for error calculations
(`timeseries_errors`), `saveat`, and tolerances.

### Type Requirements

Since this package sends data to MATLAB for computation, it only supports types
that MATLAB can handle:

- **Supported types:** `Float64`, integers (`Int64`, etc.), and `Complex{Float64}`
- **Not supported:** `BigFloat`, `Float32`, GPU arrays (`CuArray`, `JLArray`), or other custom array types

If you need arbitrary precision or GPU computing, use the native Julia solvers
from [DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) instead.

Note that the algorithms are defined to have the same name as the MATLAB algorithms,
but are not exported. Thus to use `ode45`, you would specify the algorithm as
`MATLABDiffEq.ode45()`.
Expand Down
44 changes: 44 additions & 0 deletions src/MATLABDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,39 @@ using Reexport
using MATLAB, ModelingToolkit
using PrecompileTools

# MATLAB only supports Float64 arrays. Check if a type is MATLAB-compatible.
_is_matlab_compatible_eltype(::Type{Float64}) = true
_is_matlab_compatible_eltype(::Type{<:Integer}) = true # MATLAB can convert integers
_is_matlab_compatible_eltype(::Type{<:Complex{Float64}}) = true
_is_matlab_compatible_eltype(::Type) = false

function _check_matlab_compatible(u0, tspan)
T = eltype(u0)
if !_is_matlab_compatible_eltype(T)
throw(ArgumentError(
"MATLABDiffEq.jl requires Float64-compatible element types. " *
"Got eltype(u0) = $T. MATLAB does not support arbitrary precision " *
"(BigFloat) or GPU arrays (JLArrays, CuArrays). Please convert your " *
"initial conditions to Float64: u0 = Float64.(u0)"
))
end
tT = eltype(tspan)
if !_is_matlab_compatible_eltype(tT)
throw(ArgumentError(
"MATLABDiffEq.jl requires Float64-compatible time span types. " *
"Got eltype(tspan) = $tT. MATLAB does not support arbitrary precision " *
"(BigFloat). Please use Float64 for tspan: tspan = Float64.(tspan)"
))
end
# Check that the array type itself is a standard Julia array
if !(u0 isa Array || u0 isa Number)
@warn "MATLABDiffEq.jl works best with standard Julia Arrays. " *
"Got $(typeof(u0)). The array will be converted to a standard Array " *
"before being sent to MATLAB."
end
return nothing
end

# Handle ModelingToolkit API changes: states -> unknowns
if isdefined(ModelingToolkit, :unknowns)
const mtk_states = ModelingToolkit.unknowns
Expand Down Expand Up @@ -35,6 +68,9 @@ function DiffEqBase.__solve(
callback = nothing,
kwargs...
) where {uType, tupType, isinplace, AlgType <: MATLABAlgorithm}
# Validate that input types are MATLAB-compatible
_check_matlab_compatible(prob.u0, prob.tspan)

tType = eltype(tupType)

if prob.tspan[end] - prob.tspan[1] < tType(0)
Expand Down Expand Up @@ -179,6 +215,14 @@ end

# Also precompile with missing keys (common case)
_ = buildDEStats(Dict{String, Any}())

# Precompile type compatibility checks
_ = _is_matlab_compatible_eltype(Float64)
_ = _is_matlab_compatible_eltype(Int64)
_ = _is_matlab_compatible_eltype(Complex{Float64})
_ = _is_matlab_compatible_eltype(BigFloat)
_ = _check_matlab_compatible([1.0, 2.0], (0.0, 1.0))
_ = _check_matlab_compatible(1.0, (0.0, 1.0))
end
end

Expand Down
108 changes: 108 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Interface compatibility tests for MATLABDiffEq.jl
# These tests verify type checking and interface compliance

using Test

@testset "Interface Compatibility" begin
@testset "MATLAB type compatibility checks" begin
# Test _is_matlab_compatible_eltype function
@test MATLABDiffEq._is_matlab_compatible_eltype(Float64) == true
@test MATLABDiffEq._is_matlab_compatible_eltype(Float32) == false
@test MATLABDiffEq._is_matlab_compatible_eltype(Int64) == true
@test MATLABDiffEq._is_matlab_compatible_eltype(Int32) == true
@test MATLABDiffEq._is_matlab_compatible_eltype(Complex{Float64}) == true
@test MATLABDiffEq._is_matlab_compatible_eltype(Complex{Float32}) == false
@test MATLABDiffEq._is_matlab_compatible_eltype(BigFloat) == false
@test MATLABDiffEq._is_matlab_compatible_eltype(BigInt) == false
@test MATLABDiffEq._is_matlab_compatible_eltype(Rational{Int}) == false
end

@testset "_check_matlab_compatible validation" begin
# Valid Float64 inputs should pass
@test MATLABDiffEq._check_matlab_compatible([1.0, 2.0], (0.0, 1.0)) === nothing
@test MATLABDiffEq._check_matlab_compatible(1.0, (0.0, 1.0)) === nothing
@test MATLABDiffEq._check_matlab_compatible([1, 2, 3], (0, 10)) === nothing # Integers OK

# Complex Float64 should pass
@test MATLABDiffEq._check_matlab_compatible([1.0 + 2.0im], (0.0, 1.0)) === nothing

# BigFloat u0 should throw ArgumentError
@test_throws ArgumentError MATLABDiffEq._check_matlab_compatible(
BigFloat[1.0, 2.0], (0.0, 1.0)
)

# BigFloat tspan should throw ArgumentError
@test_throws ArgumentError MATLABDiffEq._check_matlab_compatible(
[1.0, 2.0], (BigFloat(0.0), BigFloat(1.0))
)

# Float32 should throw ArgumentError (MATLAB expects Float64)
@test_throws ArgumentError MATLABDiffEq._check_matlab_compatible(
Float32[1.0, 2.0], (0.0, 1.0)
)
end

@testset "Error messages are helpful" begin
# Test that error messages contain useful information
try
MATLABDiffEq._check_matlab_compatible(BigFloat[1.0], (0.0, 1.0))
@test false # Should not reach here
catch e
@test e isa ArgumentError
@test occursin("BigFloat", e.msg)
@test occursin("Float64", e.msg)
@test occursin("MATLABDiffEq", e.msg)
end

try
MATLABDiffEq._check_matlab_compatible([1.0], (BigFloat(0.0), BigFloat(1.0)))
@test false # Should not reach here
catch e
@test e isa ArgumentError
@test occursin("tspan", lowercase(e.msg))
end
end

@testset "buildDEStats is type-generic" begin
# Test that buildDEStats works with different Dict types
stats1 = Dict{String, Any}("nfevals" => 100, "nsteps" => 50)
result1 = MATLABDiffEq.buildDEStats(stats1)
@test result1.nf == 100
@test result1.naccept == 50

# Test with empty dict
stats2 = Dict{String, Any}()
result2 = MATLABDiffEq.buildDEStats(stats2)
@test result2.nf == 0
@test result2.naccept == 0

# Test with all fields
stats3 = Dict{String, Any}(
"nfevals" => 200,
"nfailed" => 10,
"nsteps" => 190,
"nsolves" => 100,
"npds" => 20,
"ndecomps" => 15
)
result3 = MATLABDiffEq.buildDEStats(stats3)
@test result3.nf == 200
@test result3.nreject == 10
@test result3.naccept == 190
@test result3.nsolve == 100
@test result3.njacs == 20
@test result3.nw == 15
end

@testset "Algorithm structs instantiation" begin
# Test that all algorithm structs can be instantiated
@test MATLABDiffEq.ode23() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode45() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode113() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode23s() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode23t() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode23tb() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode15s() isa MATLABDiffEq.MATLABAlgorithm
@test MATLABDiffEq.ode15i() isa MATLABDiffEq.MATLABAlgorithm
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
using DiffEqBase, MATLABDiffEq, ParameterizedFunctions, Test

# Interface tests - these test type validation without needing MATLAB runtime
include("interface_tests.jl")

# The following tests require MATLAB runtime to be available
# They test the actual ODE solving functionality

f = @ode_def_bare LotkaVolterra begin
dx = a * x - b * x * y
dy = -c * y + d * x * y
Expand Down
Loading