Skip to content

Commit a2cfee6

Browse files
vyuduAayushSabharwal
authored andcommitted
fix: update callback and jump codegen in JumpProblem
1 parent 0b500c1 commit a2cfee6

File tree

1 file changed

+11
-30
lines changed

1 file changed

+11
-30
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,5 @@
11
const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}
22

3-
# modifies the expression representing an affect function to
4-
# call reset_aggregated_jumps!(integrator).
5-
# assumes iip
6-
function _reset_aggregator!(expr, integrator)
7-
@assert Meta.isexpr(expr, :function)
8-
body = expr.args[end]
9-
body = quote
10-
$body
11-
$reset_aggregated_jumps!($integrator)
12-
end
13-
expr.args[end] = body
14-
return nothing
15-
end
16-
173
"""
184
$(TYPEDEF)
195
@@ -230,8 +216,10 @@ function JumpSystem(eqs, iv, unknowns, ps;
230216
end
231217
end
232218

233-
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
234-
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
219+
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback,
220+
iv = iv, warn_no_algebraic = false)
221+
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback,
222+
iv = iv, warn_no_algebraic = false)
235223

236224
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
237225
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,
@@ -282,15 +270,13 @@ function generate_rate_function(js::JumpSystem, rate)
282270
expression = Val{true})
283271
end
284272

285-
function generate_affect_function(js::JumpSystem, affect, outputidxs)
273+
function generate_affect_function(js::JumpSystem, affect)
286274
consts = collect_constants(affect)
287275
if !isempty(consts) # The SymbolicUtils._build_function method of this case doesn't support postprocess_fbody
288276
csubs = Dict(c => getdefault(c) for c in consts)
289277
affect = substitute(affect, csubs)
290278
end
291-
compile_affect(
292-
affect, nothing, js, unknowns(js), parameters(js); outputidxs = outputidxs,
293-
expression = Val{true}, checkvars = false)
279+
compile_equational_affect(affect, js; expression = Val{true}, checkvars = false)
294280
end
295281

296282
function assemble_vrj(
@@ -299,19 +285,17 @@ function assemble_vrj(
299285
rate = GeneratedFunctionWrapper{(2, 3, is_split(js))}(rate, nothing)
300286
outputvars = (value(affect.lhs) for affect in vrj.affect!)
301287
outputidxs = [unknowntoid[var] for var in outputvars]
302-
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
303-
eval_expression, eval_module)
288+
affect = generate_affect_function(js, vrj.affect!)
304289
VariableRateJump(rate, affect; save_positions = vrj.save_positions)
305290
end
306291

307292
function assemble_vrj_expr(js, vrj, unknowntoid)
308293
rate = generate_rate_function(js, vrj.rate)
309294
outputvars = (value(affect.lhs) for affect in vrj.affect!)
310295
outputidxs = ((unknowntoid[var] for var in outputvars)...,)
311-
affect = generate_affect_function(js, vrj.affect!, outputidxs)
296+
affect = generate_affect_function(js, vrj.affect!)
312297
quote
313298
rate = $rate
314-
315299
affect = $affect
316300
VariableRateJump(rate, affect)
317301
end
@@ -323,19 +307,17 @@ function assemble_crj(
323307
rate = GeneratedFunctionWrapper{(2, 3, is_split(js))}(rate, nothing)
324308
outputvars = (value(affect.lhs) for affect in crj.affect!)
325309
outputidxs = [unknowntoid[var] for var in outputvars]
326-
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);
327-
eval_expression, eval_module)
310+
affect = generate_affect_function(js, crj.affect!)
328311
ConstantRateJump(rate, affect)
329312
end
330313

331314
function assemble_crj_expr(js, crj, unknowntoid)
332315
rate = generate_rate_function(js, crj.rate)
333316
outputvars = (value(affect.lhs) for affect in crj.affect!)
334317
outputidxs = ((unknowntoid[var] for var in outputvars)...,)
335-
affect = generate_affect_function(js, crj.affect!, outputidxs)
318+
affect = generate_affect_function(js, crj.affect!)
336319
quote
337320
rate = $rate
338-
339321
affect = $affect
340322
ConstantRateJump(rate, affect)
341323
end
@@ -574,8 +556,7 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
574556
end
575557

576558
# handle events, making sure to reset aggregators in the generated affect functions
577-
cbs = process_events(js; callback, eval_expression, eval_module,
578-
postprocess_affect_expr! = _reset_aggregator!)
559+
cbs = process_events(js; callback, eval_expression, eval_module, reset_jumps = true)
579560

580561
JumpProblem(prob, aggregator, jset; dep_graph = jtoj, vartojumps_map = vtoj,
581562
jumptovars_map = jtov, scale_rates = false, nocopy = true,

0 commit comments

Comments
 (0)