Skip to content

Commit 4e84df3

Browse files
Merge pull request #2795 from AayushSabharwal/as/obsfn-inputs
fix: fix observed function generation for systems with inputs
2 parents 82bb5af + 6534e2b commit 4e84df3

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,18 @@ function build_explicit_observed_function(sys, ts;
473473
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
474474
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
475475
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
476+
if isempty(ps) && inputs !== nothing
477+
ps = (:EMPTY,)
478+
end
476479
else
477480
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
478481
end
479482
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
480483
if inputs === nothing
481484
args = [dvs, ps..., ivs...]
482485
else
483-
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
486+
inputs = unwrap.(inputs)
487+
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
484488
args = [dvs, ipts, ps..., ivs...]
485489
end
486490
pre = get_postprocess_fbody(sys)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,18 +306,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
306306
_jac = nothing
307307
end
308308

309-
observedfun = let sys = sys, dict = Dict()
310-
function generated_observed(obsvar, u, p)
311-
obs = get!(dict, value(obsvar)) do
312-
build_explicit_observed_function(sys, obsvar)
313-
end
314-
if p isa MTKParameters
315-
obs(u, p...)
316-
else
317-
obs(u, p)
318-
end
319-
end
320-
end
309+
observedfun = ObservedFunctionCache(sys)
321310

322311
NonlinearFunction{iip}(f,
323312
sys = sys,

test/input_output_handling.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,16 @@ matrices, ssys = linearize(augmented_sys,
378378
# P = ss(A,B,C,0)
379379
# G = ss(matrices...)
380380
# @test sminreal(G[1, 3]) ≈ sminreal(P[1,1])*dist
381+
382+
@testset "Observed functions with inputs" begin
383+
@variables x(t)=0 u(t)=0 [input = true]
384+
eqs = [
385+
D(x) ~ -x + u
386+
]
387+
388+
@named sys = ODESystem(eqs, t)
389+
(; io_sys,) = ModelingToolkit.generate_control_function(sys, simplify = true)
390+
obsfn = ModelingToolkit.build_explicit_observed_function(
391+
io_sys, [x + u * t]; inputs = [u])
392+
@test obsfn([1.0], [2.0], nothing, 3.0) == [7.0]
393+
end

0 commit comments

Comments
 (0)