Skip to content

Commit 42cc9cf

Browse files
fix: pass operating point to ImplicitDiscreteProblem in generate_equational_affect
1 parent 58e917f commit 42cc9cf

File tree

8 files changed

+52
-9
lines changed

8 files changed

+52
-9
lines changed

src/problems/daeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
eval_module, check_compatibility, implicit_dae = true, expression, kwargs...)
7373

7474
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
75-
kwargs...)
75+
op, kwargs...)
7676

7777
diffvars = collect_differential_variables(sys)
7878
sts = unknowns(sys)

src/problems/ddeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
end
6767

6868
kwargs = process_kwargs(
69-
sys; expression, callback, eval_expression, eval_module, kwargs...)
69+
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
7070
args = (; f, u0, h, tspan, p)
7171

7272
return maybe_codegen_scimlproblem(expression, DDEProblem{iip}, args; kwargs...)

src/problems/jumpproblem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@
8080
end
8181

8282
# handle events, making sure to reset aggregators in the generated affect functions
83-
cbs = process_events(sys; callback, eval_expression, eval_module, reset_jumps = true)
83+
cbs = process_events(
84+
sys; callback, eval_expression, eval_module, op, reset_jumps = true)
8485

8586
if rng !== nothing
8687
kwargs = (; kwargs..., rng)

src/problems/odeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
eval_module, expression, check_compatibility, kwargs...)
7676

7777
kwargs = process_kwargs(
78-
sys; expression, callback, eval_expression, eval_module, kwargs...)
78+
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
7979

8080
ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
8181
args = (; f, u0, tspan, p, ptype)

src/problems/sddeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868
end
6969

7070
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
71-
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
71+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, op, kwargs...)
7272

7373
if expression == Val{true}
7474
g = :(f.g)

src/problems/sdeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
8080
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
81-
kwargs...)
81+
op, kwargs...)
8282

8383
args = (; f, u0, tspan, p)
8484
kwargs = (; noise, noise_rate_prototype, kwargs...)

src/systems/callbacks.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,12 +798,27 @@ function add_integrator_header(
798798
expr.body)
799799
end
800800

801+
function default_operating_point(affsys::AffectSystem)
802+
sys = system(affsys)
803+
804+
op = Dict(unknowns(sys) .=> 0.0)
805+
for p in parameters(sys)
806+
T = symtype(p)
807+
if T <: Number
808+
op[p] = false
809+
elseif T <: Array{<:Real} && is_sized_array_symbolic(p)
810+
op[p] = zeros(size(p))
811+
end
812+
end
813+
return op
814+
end
815+
801816
"""
802817
Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
803818
"""
804819
function compile_equational_affect(
805820
aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false,
806-
eval_expression = false, eval_module = @__MODULE__, kwargs...)
821+
eval_expression = false, eval_module = @__MODULE__, op = default_operating_point(aff), kwargs...)
807822
if aff isa AbstractVector
808823
aff = make_affect(
809824
aff; iv = get_iv(sys), warn_no_algebraic = false)
@@ -872,10 +887,10 @@ function compile_equational_affect(
872887
p_getter = getsym(affsys, ps_to_update)
873888

874889
affprob = ImplicitDiscreteProblem(
875-
affsys, Pair[unknowns(affsys) .=> 0; parameters(affsys) .=> 0],
890+
affsys, op,
876891
(0, 0);
877892
build_initializeprob = false, check_length = false, eval_expression,
878-
eval_module, check_compatibility = false)
893+
eval_module, check_compatibility = false, kwargs...)
879894

880895
function implicit_affect!(integ)
881896
new_u0 = affu_getter(integ)

test/symbolic_events.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,30 @@ end
13481348
@test SciMLBase.successful_retcode(sol)
13491349
@test sol[inner.p][end] 1.0
13501350
end
1351+
1352+
mutable struct ParamTest
1353+
y::Any
1354+
end
1355+
1356+
@testset "callable parameter and symbolic affect" begin
1357+
(pt::ParamTest)(x) = pt.y - x
1358+
1359+
p1 = ParamTest(1)
1360+
tp1 = typeof(p1)
1361+
@parameters (p_1::tp1)(..) = p1
1362+
@parameters p2(t) = 1.0
1363+
@variables x(t) = 0.0
1364+
@variables x2(t)
1365+
event = [0.5] => [p2 ~ Pre(t)]
1366+
1367+
eq = [
1368+
D(x) ~ p2,
1369+
x2 ~ p_1(x)
1370+
]
1371+
@mtkcompile sys = ODESystem(eq, t, [x, x2], [p_1, p2], discrete_events = [event])
1372+
1373+
prob = ODEProblem(sys, [], (0.0, 1.0))
1374+
sol = solve(prob)
1375+
@test SciMLBase.successful_retcode(sol)
1376+
@test sol[x, end]1.0 atol=1e-6
1377+
end

0 commit comments

Comments
 (0)