Skip to content

Commit 18a2a89

Browse files
committed
Add one-arg observed dispatch
1 parent 19a2c0e commit 18a2a89

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

src/structural_transformation/codegen.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,20 @@ function build_torn_function(sys;
318318
sol_states = sol_states,
319319
var2assignment = var2assignment
320320

321-
function generated_observed(obsvar, u, p, t)
321+
function generated_observed(obsvar, args...)
322322
obs = get!(dict, value(obsvar)) do
323323
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
324324
is_solver_state_idxs, assignments, deps,
325325
sol_states, var2assignment,
326326
checkbounds = checkbounds)
327327
end
328-
obs(u, p, t)
328+
if args === ()
329+
let obs = obs
330+
(u, p, t) -> obs(u, p, t)
331+
end
332+
else
333+
obs(args...)
334+
end
329335
end
330336
end
331337

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,32 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
328328
obs = observed(sys)
329329
observedfun = if steady_state
330330
let sys = sys, dict = Dict()
331-
function generated_observed(obsvar, u, p, t = Inf)
331+
function generated_observed(obsvar, args...)
332332
obs = get!(dict, value(obsvar)) do
333333
build_explicit_observed_function(sys, obsvar)
334334
end
335-
obs(u, p, t)
335+
if args === ()
336+
let obs = obs
337+
(u, p, t = Inf) -> obs(u, p, t)
338+
end
339+
else
340+
length(args) == 2 ? obs(args..., Inf) : obs(args...)
341+
end
336342
end
337343
end
338344
else
339345
let sys = sys, dict = Dict()
340-
function generated_observed(obsvar, u, p, t)
346+
function generated_observed(obsvar, args...)
341347
obs = get!(dict, value(obsvar)) do
342348
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
343349
end
344-
obs(u, p, t)
350+
if args === ()
351+
let obs = obs
352+
(u, p, t) -> obs(u, p, t)
353+
end
354+
else
355+
obs(args...)
356+
end
345357
end
346358
end
347359
end
@@ -424,11 +436,17 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
424436

425437
obs = observed(sys)
426438
observedfun = let sys = sys, dict = Dict()
427-
function generated_observed(obsvar, u, p, t)
439+
function generated_observed(obsvar, args...)
428440
obs = get!(dict, value(obsvar)) do
429441
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
430442
end
431-
obs(u, p, t)
443+
if args === ()
444+
let obs = obs
445+
(u, p, t) -> obs(u, p, t)
446+
end
447+
else
448+
obs(args...)
449+
end
432450
end
433451
end
434452

0 commit comments

Comments
 (0)