Skip to content

Commit 3149f77

Browse files
authored
Merge pull request #1893 from SciML/myb/fastobs
Add one-arg observed dispatch
2 parents 4c30760 + a47662f commit 3149f77

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

src/structural_transformation/codegen.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,20 @@ function build_torn_function(sys;
318318
sol_states = sol_states,
319319
var2assignment = var2assignment
320320

321-
function generated_observed(obsvar, u, p, t)
321+
function generated_observed(obsvar, args...)
322322
obs = get!(dict, value(obsvar)) do
323323
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
324324
is_solver_state_idxs, assignments, deps,
325325
sol_states, var2assignment,
326326
checkbounds = checkbounds)
327327
end
328-
obs(u, p, t)
328+
if args === ()
329+
let obs = obs
330+
(u, p, t) -> obs(u, p, t)
331+
end
332+
else
333+
obs(args...)
334+
end
329335
end
330336
end
331337

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,32 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
328328
obs = observed(sys)
329329
observedfun = if steady_state
330330
let sys = sys, dict = Dict()
331-
function generated_observed(obsvar, u, p, t = Inf)
331+
function generated_observed(obsvar, args...)
332332
obs = get!(dict, value(obsvar)) do
333333
build_explicit_observed_function(sys, obsvar)
334334
end
335-
obs(u, p, t)
335+
if args === ()
336+
let obs = obs
337+
(u, p, t = Inf) -> obs(u, p, t)
338+
end
339+
else
340+
length(args) == 2 ? obs(args..., Inf) : obs(args...)
341+
end
336342
end
337343
end
338344
else
339345
let sys = sys, dict = Dict()
340-
function generated_observed(obsvar, u, p, t)
346+
function generated_observed(obsvar, args...)
341347
obs = get!(dict, value(obsvar)) do
342348
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
343349
end
344-
obs(u, p, t)
350+
if args === ()
351+
let obs = obs
352+
(u, p, t) -> obs(u, p, t)
353+
end
354+
else
355+
obs(args...)
356+
end
345357
end
346358
end
347359
end
@@ -424,11 +436,17 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
424436

425437
obs = observed(sys)
426438
observedfun = let sys = sys, dict = Dict()
427-
function generated_observed(obsvar, u, p, t)
439+
function generated_observed(obsvar, args...)
428440
obs = get!(dict, value(obsvar)) do
429441
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
430442
end
431-
obs(u, p, t)
443+
if args === ()
444+
let obs = obs
445+
(u, p, t) -> obs(u, p, t)
446+
end
447+
else
448+
obs(args...)
449+
end
432450
end
433451
end
434452

test/components.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ end
2121

2222
function check_rc_sol(sol)
2323
rpi = sol[rc_model.resistor.p.i]
24+
rpifun = sol.prob.f.observed(rc_model.resistor.p.i)
25+
@test rpifun.(sol.u, (sol.prob.p,), sol.t) == rpi
2426
@test any(!isequal(rpi[1]), rpi) # test that we don't have a constant system
2527
@test sol[rc_model.resistor.p.i] == sol[resistor.p.i] == sol[capacitor.p.i]
2628
@test sol[rc_model.resistor.n.i] == sol[resistor.n.i] == -sol[capacitor.p.i]

0 commit comments

Comments
 (0)