Skip to content

Commit c7e108b

Browse files
committed
Add observed function for NonlinearFunctions
1 parent f1a070b commit c7e108b

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ i.e. there are no cycles.
234234
function build_explicit_observed_function(
235235
sys, syms;
236236
expression=false,
237-
output_type=Array)
237+
output_type=Array,
238+
)
238239

239240
if (isscalar = !(syms isa Vector))
240241
syms = [syms]
@@ -252,13 +253,13 @@ function build_explicit_observed_function(
252253
output[i] = obs[idx].rhs
253254
end
254255

256+
dvs = DestructuredArgs(states(sys))
257+
ps = DestructuredArgs(parameters(sys))
258+
iv = independent_variable(sys)
259+
args = iv === nothing ? [dvs, ps] : [dv, ps, iv]
260+
255261
ex = Func(
256-
[
257-
DestructuredArgs(states(sys))
258-
DestructuredArgs(parameters(sys))
259-
independent_variable(sys)
260-
],
261-
[],
262+
args, [],
262263
Let(
263264
map(eq -> eq.lhseq.rhs, obs[1:maxidx]),
264265
isscalar ? output[1] : MakeArray(output, output_type)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,19 @@ function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sy
143143
_jac = nothing
144144
end
145145

146+
observedfun = let sys = sys, dict = Dict()
147+
function generated_observed(obsvar, u, p)
148+
obs = get!(dict, value(obsvar)) do
149+
build_explicit_observed_function(sys, obsvar)
150+
end
151+
obs(u, p)
152+
end
153+
end
154+
146155
NonlinearFunction{iip}(f,
147156
jac = _jac === nothing ? nothing : _jac,
148157
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
149-
syms = Symbol.(states(sys)))
158+
syms = Symbol.(states(sys)), observed = observedfun)
150159
end
151160

152161
"""

test/reduction.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ nlprob = NonlinearProblem(reducedsys, u0, pp)
200200
reducedsol = solve(nlprob, NewtonRaphson())
201201
residual = fill(100.0, length(states(reducedsys)))
202202
nlprob.f(residual, reducedsol.u, pp)
203+
@test hypot(nlprob.f.observed(u2, reducedsol.u, pp), nlprob.f.observed(u1, reducedsol.u, pp)) * pp reducedsol.u atol=1e-9
204+
203205
@test all(x->abs(x) < 1e-5, residual)
204206

205207
N = 5

0 commit comments

Comments
 (0)