From 28a6ed2b52dec6b7581f3f50981ffc8c17628c6f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Feb 2025 16:31:33 +0530 Subject: [PATCH 1/3] feat: throw error when differentiating registered function with no derivative in `structural_simplify` --- src/structural_transformation/StructuralTransformations.jl | 1 + src/structural_transformation/symbolics_tearing.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 2e600e86e4..4adc817ef8 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -4,6 +4,7 @@ using Setfield: @set!, @set using UnPack: @unpack using Symbolics: unwrap, linear_expansion, fast_substitute +import Symbolics using SymbolicUtils using SymbolicUtils.Code using SymbolicUtils.Rewriters diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index f1cdd7ce9c..6845351008 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -65,7 +65,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...) sys = ts.sys eq = equations(ts)[ieq] - eq = 0 ~ ModelingToolkit.derivative(eq.rhs - eq.lhs, get_iv(sys)) + eq = 0 ~ Symbolics.derivative(eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true) push!(equations(ts), eq) # Analyze the new equation and update the graph/solvable_graph # First, copy the previous incidence and add the derivative terms. From 5aceb9e9add2b8549f79489e74ba3b242fc6a5ba Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 21:43:53 +0530 Subject: [PATCH 2/3] test: add required derivative, fix initialization in split parameters test --- test/split_parameters.jl | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 74f8d41d73..1052f4ad27 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -51,8 +51,8 @@ end get_value(interp::Interpolator, t) = interp(t) @register_symbolic get_value(interp::Interpolator, t) -# get_value(data, t, dt) = data[round(Int, t / dt + 1)] -# @register_symbolic get_value(data::Vector, t, dt) + +Symbolics.derivative(::typeof(get_value), args::NTuple{2, Any}, ::Val{2}) = 0 function Sampled(; name, interp = Interpolator(Float64[], 0.0)) pars = @parameters begin @@ -68,11 +68,10 @@ function Sampled(; name, interp = Interpolator(Float64[], 0.0)) output.u ~ get_value(interpolator, t) ] - return ODESystem(eqs, t, vars, [interpolator]; name, systems, - defaults = [output.u => interp.data[1]]) + return ODESystem(eqs, t, vars, [interpolator]; name, systems) end -vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 +vars = @variables y(t) dy(t) ddy(t) @named src = Sampled(; interp = Interpolator(x, dt)) @named int = Integrator() @@ -84,11 +83,9 @@ eqs = [y ~ src.output.u @named sys = ODESystem(eqs, t, vars, []; systems = [int, src]) s = complete(sys) sys = structural_simplify(sys) -@test_broken ODEProblem( - sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)]; tofloat = false) prob = ODEProblem( sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)]; - tofloat = false, build_initializeprob = false) + tofloat = false) sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success @test sol[y][end] == x[end] From 8c43ac5dd9bde84a35619bd8e4817d6e5d7ffaf4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Mar 2025 16:52:54 +0530 Subject: [PATCH 3/3] build: bump Symbolics compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 92d8f98ab4..a9290b037a 100644 --- a/Project.toml +++ b/Project.toml @@ -151,7 +151,7 @@ StochasticDelayDiffEq = "1.8.1" StochasticDiffEq = "6.72.1" SymbolicIndexingInterface = "0.3.37" SymbolicUtils = "3.14" -Symbolics = "6.29.1" +Symbolics = "6.29.2" URIs = "1" UnPack = "0.1, 1.0" Unitful = "1.1"