diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 1c4e780e64..2a5bc1a728 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -943,6 +943,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys - `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist. - `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function. - `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function. +- `wrap_delays = is_dde(sys)`: Whether to add an argument for the history function and use + it to calculate all delayed variables. ## Returns @@ -981,7 +983,8 @@ function build_explicit_observed_function(sys, ts; op = Operator, throw = true, cse = true, - mkarray = nothing) + mkarray = nothing, + wrap_delays = is_dde(sys)) # TODO: cleanup is_tuple = ts isa Tuple if is_tuple @@ -1068,14 +1071,15 @@ function build_explicit_observed_function(sys, ts; p_end = length(dvs) + length(inputs) + length(ps) fns = build_function_wrapper( sys, ts, args...; p_start, p_end, filter_observed = obsfilter, - output_type, mkarray, try_namespaced = true, expression = Val{true}, cse) + output_type, mkarray, try_namespaced = true, expression = Val{true}, cse, + wrap_delays) if fns isa Tuple if expression return return_inplace ? fns : fns[1] end oop, iip = eval_or_rgf.(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}( + p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( oop, iip) return return_inplace ? (f, f) : f else @@ -1084,7 +1088,7 @@ function build_explicit_observed_function(sys, ts; end f = eval_or_rgf(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}( + p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( f, nothing) return f end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 61d2292925..2016b1efd8 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -646,30 +646,40 @@ struct ReconstructInitializeprob{GP, GU} ugetter::GU end +""" + $(TYPEDEF) + +A wrapper over an observed function which allows calling it on a problem-like object. +`TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if +`false`). +""" +struct ObservedWrapper{TD, F} + f::F +end + +ObservedWrapper{TD}(f::F) where {TD, F} = ObservedWrapper{TD, F}(f) + +function (ow::ObservedWrapper{true})(prob) + ow.f(state_values(prob), parameter_values(prob), current_time(prob)) +end + +function (ow::ObservedWrapper{false})(prob) + ow.f(state_values(prob), parameter_values(prob)) +end + """ $(TYPEDSIGNATURES) Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter -function by splitting `syms` into contiguous buffers where the getter of each buffer -is type-stable and constructing a function that calls and concatenates the results. -""" -function concrete_getu(indp, syms::AbstractVector) - # a list of contiguous buffer - split_syms = [Any[syms[1]]] - # the type of the getter of the last buffer - current = typeof(getu(indp, syms[1])) - for sym in syms[2:end] - getter = getu(indp, sym) - if typeof(getter) != current - # if types don't match, build a new buffer - push!(split_syms, []) - current = typeof(getter) - end - push!(split_syms[end], sym) - end - split_syms = Tuple(split_syms) - # the getter is now type-stable, and we can vcat it to get the full buffer - return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms) +function. + +Note that the getter ONLY works for problem-like objects, since it generates an observed +function. It does NOT work for solutions. +""" +Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector) + @nospecialize + obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false) + return ObservedWrapper{is_time_dependent(indp)}(obsfn) end """ diff --git a/src/utils.jl b/src/utils.jl index dd6971dbe1..2b8ec4a7a0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -830,7 +830,7 @@ Keyword arguments: `available_vars` will not be searched for in the observed equations. """ function observed_equations_used_by(sys::AbstractSystem, exprs; - involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys), available_vars = []) + involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = []) obsvars = getproperty.(obs, :lhs) graph = observed_dependency_graph(obs) if !(available_vars isa Set) diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 2efc9781cf..7e9cbdd740 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -27,13 +27,11 @@ sol = solve(prob, Tsit5()) mtkparams = parameter_values(prob) new_p = rand(14) -@test_broken begin - gs = gradient(new_p) do new_p - new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p) - new_prob = remake(prob, p = new_params) - new_sol = solve(new_prob, Tsit5()) - sum(new_sol) - end +gs = gradient(new_p) do new_p + new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p) + new_prob = remake(prob, p = new_params) + new_sol = solve(new_prob, Tsit5()) + sum(new_sol) end @testset "Issue#2997" begin diff --git a/test/odesystem.jl b/test/odesystem.jl index dc8f3f8349..28bc1d7db0 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -6,11 +6,14 @@ using OrdinaryDiffEq, Sundials using DiffEqBase, SparseArrays using StaticArrays using Test -using SymbolicUtils: issym +using SymbolicUtils.Code +using SymbolicUtils: Sym, issym using ForwardDiff using ModelingToolkit: value using ModelingToolkit: t_nounits as t, D_nounits as D +using Symbolics using Symbolics: unwrap +using DiffEqBase: isinplace # Define some variables @parameters σ ρ β @@ -505,13 +508,6 @@ sys = complete(sys) @test_throws Any ODEFunction(sys) @testset "Preface tests" begin - using OrdinaryDiffEq - using Symbolics - using DiffEqBase: isinplace - using ModelingToolkit - using SymbolicUtils.Code - using SymbolicUtils: Sym - c = [0] function f(c, du::AbstractVector{Float64}, u::AbstractVector{Float64}, p, t::Float64) c .= [c[1] + 1] @@ -554,7 +550,9 @@ sys = complete(sys) @named sys = System(eqs, t, us, ps; defaults = defs, preface = preface) sys = complete(sys) - prob = ODEProblem(sys, [], (0.0, 1.0)) + # don't build initializeprob because it will use preface in other functions and + # affect `c` + prob = ODEProblem(sys, [], (0.0, 1.0); build_initializeprob = false) sol = solve(prob, Euler(); dt = 0.1) @test c[1] == length(sol)