Skip to content

Commit f74f184

Browse files
fix: retain nonnumeric parameter dependencies in initialization system
1 parent 74412cd commit f74f184

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,13 @@ function generate_initializesystem(sys::ODESystem;
138138
end
139139

140140
# 5) parameter dependencies become equations, their LHS become unknowns
141+
# non-numeric dependent parameters stay as parameter dependencies
142+
new_parameter_deps = Equation[]
141143
for eq in parameter_dependencies(sys)
142-
is_variable_floatingpoint(eq.lhs) || continue
144+
if !is_variable_floatingpoint(eq.lhs)
145+
push!(new_parameter_deps, eq)
146+
continue
147+
end
143148
varp = tovar(eq.lhs)
144149
paramsubs[eq.lhs] = varp
145150
push!(eqs_ics, eq)
@@ -171,6 +176,7 @@ function generate_initializesystem(sys::ODESystem;
171176
pars;
172177
defaults = defs,
173178
checks = check_units,
179+
parameter_dependencies = new_parameter_deps,
174180
name,
175181
metadata = meta,
176182
kwargs...)

test/initializationsystem.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,3 +815,24 @@ end
815815
prob2 = @test_nowarn remake(prob; u0 = [y => 0.5])
816816
@test is_variable(prob.f.initializeprob, p)
817817
end
818+
819+
struct Multiplier{T}
820+
a::T
821+
b::T
822+
end
823+
824+
function (m::Multiplier)(x, y)
825+
m.a * x + m.b * y
826+
end
827+
828+
@register_symbolic Multiplier(x::Real, y::Real)
829+
830+
@testset "Nonnumeric parameter dependencies are retained" begin
831+
@variables x(t) y(t)
832+
@parameters foo(::Real, ::Real) p
833+
@mtkbuild sys = ODESystem([D(x) ~ t, 0 ~ foo(x, y)], t;
834+
parameter_dependencies = [foo ~ Multiplier(p, 2p)], guesses = [y => -1.0])
835+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
836+
integ = init(prob, Rosenbrock23())
837+
@test integ[y] -0.5
838+
end

0 commit comments

Comments
 (0)