diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d91659c0d7..aa02a5fb73 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -3090,7 +3090,7 @@ function process_parameter_dependencies(pdeps, ps) end for p in pdeps] end - lhss = BasicSymbolic[] + lhss = [] for p in pdeps if !isparameter(p.lhs) error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).") @@ -3101,6 +3101,7 @@ function process_parameter_dependencies(pdeps, ps) end push!(lhss, p.lhs) end + lhss = map(identity, lhss) pdeps = topsort_equations(pdeps, union(ps, lhss)) ps = filter(ps) do p !any(isequal(p), lhss) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 90ca3eb781..00f7837407 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -49,7 +49,7 @@ struct IndexCache constant_idx::ParamIndexMap nonnumeric_idx::NonnumericMap observed_syms::Set{BasicSymbolic} - dependent_pars::Set{BasicSymbolic} + dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_size::BufferTemplate constant_buffer_sizes::Vector{BufferTemplate} @@ -275,7 +275,8 @@ function IndexCache(sys::AbstractSystem) end end - dependent_pars = Set{BasicSymbolic}() + dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}() + for eq in parameter_dependencies(sys) sym = eq.lhs ttsym = default_toterm(sym) diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 034c27041e..9cfd4ca5c5 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -177,6 +177,29 @@ end @test SciMLBase.successful_retcode(sol) end +struct CallableFoo + p::Any +end + +@register_symbolic CallableFoo(x) + +(f::CallableFoo)(x) = f.p + x + +@testset "callable parameters" begin + @variables y(t) = 1 + @parameters p=2 (i::CallableFoo)(..) + + eqs = [D(y) ~ i(t) + p] + @named model = ODESystem(eqs, t, [y], [p, i]; + parameter_dependencies = [i ~ CallableFoo(p)]) + sys = structural_simplify(model) + + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Tsit5()) + + @test SciMLBase.successful_retcode(sol) +end + @testset "Clock system" begin dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t)