Skip to content

Commit 528cd24

Browse files
committed
Add docstring for build_explicit_observed_function and allow the caller to specify how the oop array is built.
1 parent b52bce7 commit 528cd24

File tree

7 files changed

+37
-7
lines changed

7 files changed

+37
-7
lines changed

dev/DiffEqBase

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 210c83c5ee2283acc441458c185f1f62bb38e426

dev/ODEInterfaceDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 8178c1dd3a9f56c36bbe2e2c876bb8a1612aa5a7

dev/OrdinaryDiffEq

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit ad7891e95d8907b82adb31b5fbaa0d2d7d38a791

dev/SciMLBase

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit c61b13d8f28ac3dc359350d7c64a2a697b569873

dev/StochasticDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit d289cfba3c783bfb2198009937c0591d0743537c

dev/Sundials.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit edf7c9d928191ffdb7a2eaa750b65abf6480c154

src/systems/diffeqs/odesystem.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,32 @@ ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...
411411
"""
412412
$(SIGNATURES)
413413
414-
Build the observed function assuming the observed equations are all explicit,
415-
i.e. there are no cycles.
414+
Generates a function that computes the observed value(s) `ts` in the system `sys` assuming that there are no cycles in the equations.
415+
416+
The return value will be either:
417+
* a single function if the input is a scalar or if the input is a Vector but `return_inplace` is false
418+
* the out of place and in-place functions `(ip, oop)` if `return_inplace` is true and the input is a `Vector`
419+
420+
The function(s) will be:
421+
* `RuntimeGeneratedFunction`s by default,
422+
* A Julia `Expr` if `expression` is true,
423+
* A directly evaluated Julia function in the module `eval_module` if `eval_expression` is true
424+
425+
The signatures will be of the form `g(...)` with arguments:
426+
* `output` for in-place functions
427+
* `unknowns` if `params_only` is `false`
428+
* `inputs` if `inputs` is an array of symbolic inputs that should be available in `ts`
429+
* `p...` unconditionally; note that in the case of `MTKParameters` more than one parameters argument may be present, so it must be splatted
430+
* `t` if the system is time-dependent; for example `NonlinearSystem` will not have `t`
431+
For example, a function `g(op, unknowns, p, inputs, t)` will be the in-place function generated if `return_inplace` is true, `ts` is a vector, an array of inputs `inputs` is given, and `params_only` is false for a time-dependent system.
432+
433+
Options not otherwise specified are:
434+
* `output_type = Array` the type of the array generated by the out-of-place vector-valued function
435+
* `checkbounds = true` checks bounds if true when destructuring parameters
436+
* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
437+
* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist
438+
* `drop_expr` is deprecated.
439+
* `mkarray`; only used if the output is an array (that is, `!isscalar(ts)`). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
416440
"""
417441
function build_explicit_observed_function(sys, ts;
418442
inputs = nothing,
@@ -426,7 +450,8 @@ function build_explicit_observed_function(sys, ts;
426450
return_inplace = false,
427451
param_only = false,
428452
op = Operator,
429-
throw = true)
453+
throw = true,
454+
mkarray = MakeArray)
430455
if (isscalar = symbolic_type(ts) !== NotSymbolic())
431456
ts = [ts]
432457
end
@@ -571,12 +596,11 @@ function build_explicit_observed_function(sys, ts;
571596
oop_mtkp_wrapper = mtkparams_wrapper
572597
end
573598

599+
output_expr = isscalar ? ts[1] : mkarray(ts, output_type)
574600
# Need to keep old method of building the function since it uses `output_type`,
575601
# which can't be provided to `build_function`
576-
oop_fn = Func(args, [],
577-
pre(Let(obsexprs,
578-
isscalar ? ts[1] : MakeArray(ts, output_type),
579-
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
602+
oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |>
603+
oop_mtkp_wrapper |> toexpr
580604
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
581605

582606
if !isscalar

0 commit comments

Comments
 (0)