Skip to content

Commit 36b41f6

Browse files
committed
fix h, doc, rename
1 parent 42bc4e1 commit 36b41f6

File tree

9 files changed

+58
-22
lines changed

9 files changed

+58
-22
lines changed

include/amici/model_dae.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ class Model_DAE : public Model {
446446
* @param x Vector with the states
447447
* @param p parameter vector
448448
* @param k constants vector
449-
* @param h heavyside vector
449+
* @param h Heaviside vector
450450
* @param dx Vector with the derivative states
451451
* @param w vector with helper variables
452452
*/

models/model_nested_events_py/h.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
#define injection h[0]
2-
#define Heaviside_1 h[1]
3-
#define Heaviside_2 h[2]
1+
#define Heaviside_1 h[0]
2+
#define Heaviside_2 h[1]
3+
#define injection h[2]

python/sdist/amici/de_model.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import re
99
from collections.abc import Callable, Sequence
10+
from contextlib import suppress
1011
from itertools import chain
1112
from operator import itemgetter
1213
from typing import TYPE_CHECKING
@@ -651,9 +652,12 @@ def num_events_solver(self) -> int:
651652
:return:
652653
number of event symbols (length of the root vector in AMICI)
653654
"""
654-
constant_syms = self._static_symbols(["k", "p", "w"])
655+
# TODO(performance): we could include constant `x` here as well
656+
# (dx/dt = 0 AND not target of event assignments)
657+
# this will require passing `x` to `fexplicit_roots`
658+
static_syms = self._static_symbols(["k", "p", "w"])
655659
return sum(
656-
not event.has_explicit_trigger_times(constant_syms)
660+
not event.has_explicit_trigger_times(static_syms)
657661
for event in self.events()
658662
)
659663

@@ -1137,6 +1141,7 @@ def generate_basic_variables(self) -> None:
11371141
Generates the symbolic identifiers for all variables in
11381142
``DEModel._variable_prototype``
11391143
"""
1144+
self.parse_events()
11401145
self._reorder_events()
11411146

11421147
for var in self._variable_prototype:
@@ -1163,11 +1168,6 @@ def parse_events(self) -> None:
11631168
for expr in self._expressions:
11641169
expr.set_val(self._process_heavisides(expr.get_val(), roots))
11651170

1166-
# remove all possible Heavisides from roots, which may arise from
1167-
# the substitution of `'w'` in `_collect_heaviside_roots`
1168-
for root in roots:
1169-
root.set_val(self._process_heavisides(root.get_val(), roots))
1170-
11711171
# Now add the found roots to the model components
11721172
for root in roots:
11731173
# skip roots of SBML events, as these have already been added
@@ -1181,20 +1181,57 @@ def _reorder_events(self) -> None:
11811181
Re-order events - first those that require root tracking,
11821182
then the others.
11831183
"""
1184-
constant_syms = self._static_symbols(["k", "p", "w"])
1184+
# Currently, the C++ simulations relies on the order of events:
1185+
# those that require numerical root-finding must come first, then
1186+
# those with explicit trigger times that don't depend on dynamic
1187+
# variables.
1188+
# TODO: This re-ordering here is a bit ugly, because we already need
1189+
# to generate certain model equations to perform this ordering.
1190+
# Ideally, we'd split froot into explicit and implicit parts during
1191+
# code generation instead (as already done for jax models).
1192+
static_syms = self._static_symbols(["k", "p", "w"])
1193+
1194+
# ensure that we don't have computed any root-related symbols/equations
1195+
# yet, because the re-ordering might invalidate them
1196+
# check after `self._static_symbols` which itself generates certain
1197+
# equations
1198+
if (
1199+
generated := set(self._syms)
1200+
| set(self._eqs)
1201+
| set(self._sparsesyms)
1202+
| set(self._sparseeqs)
1203+
) and (
1204+
"root" in generated
1205+
or any(
1206+
name.startswith("droot") or name.endswith("droot")
1207+
for name in generated
1208+
)
1209+
):
1210+
raise AssertionError(
1211+
"This function must be called before computing any "
1212+
"root-related symbols/equations. "
1213+
"The following symbols/equations are already "
1214+
f"generated: {generated}"
1215+
)
1216+
11851217
self._events = list(
11861218
chain(
11871219
itertools.filterfalse(
1188-
lambda e: e.has_explicit_trigger_times(constant_syms),
1220+
lambda e: e.has_explicit_trigger_times(static_syms),
11891221
self._events,
11901222
),
11911223
filter(
1192-
lambda e: e.has_explicit_trigger_times(constant_syms),
1224+
lambda e: e.has_explicit_trigger_times(static_syms),
11931225
self._events,
11941226
),
11951227
)
11961228
)
11971229

1230+
# regenerate after re-ordering
1231+
with suppress(KeyError):
1232+
del self._syms["h"]
1233+
self.sym("h")
1234+
11981235
def get_appearance_counts(self, idxs: list[int]) -> list[int]:
11991236
"""
12001237
Counts how often a state appears in the time derivative of

python/sdist/amici/exporters/sundials/de_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,14 +914,14 @@ def _get_create_splines_body(self):
914914
def _get_explicit_roots_body(self) -> list[str]:
915915
events = self.model.events()
916916
lines = []
917-
constant_syms = self.model._static_symbols(["k", "p", "w"])
917+
static_syms = self.model._static_symbols(["k", "p", "w"])
918918

919919
for event_idx, event in enumerate(events):
920920
if not (
921921
tigger_times := {
922922
tt
923923
for tt in event.get_trigger_times()
924-
if tt.free_symbols.issubset(constant_syms)
924+
if tt.free_symbols.issubset(static_syms)
925925
}
926926
):
927927
continue

python/sdist/amici/importers/pysb/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,6 @@ def ode_model_from_pysb_importer(
398398

399399
_process_stoichiometric_matrix(model, ode, fixed_parameters)
400400

401-
ode.parse_events()
402401
ode.generate_basic_variables()
403402

404403
return ode

python/sdist/amici/importers/sbml/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,7 @@ def _build_ode_model(
700700
if hybridization:
701701
ode_model._process_hybridization(hybridization)
702702

703-
ode_model.parse_events()
704703
# substitute SBML-rateOf constructs
705-
# must be done after parse_events, but before generate_basic_variables
706704
self._process_sbml_rate_of(ode_model)
707705

708706
# fill in 'self._sym' based on prototypes and components in ode_model

python/tests/test_pysb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def test_names_and_ids(pysb_example_presimulation_module):
328328

329329

330330
@skip_on_valgrind
331-
def test_heavyside_and_special_symbols():
331+
def test_heaviside_and_special_symbols():
332332
pysb.SelfExporter.cleanup() # reset pysb
333333
pysb.SelfExporter.do_export = True
334334

src/model_dae.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ void Model_DAE::froot(
9595
) {
9696
std::ranges::fill(root, 0.0);
9797
auto const x_pos = compute_x_pos(x);
98-
// TODO too costly? only dynamic expressions?
98+
// TODO(performance) only dynamic expressions? consider storing a flag
99+
// for whether froot actually depends on `w`.
99100
fw(t, N_VGetArrayPointerConst(x_pos));
100101
froot(
101102
root.data(), t, N_VGetArrayPointerConst(x_pos),

src/model_ode.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ void Model_ODE::froot(
8484
realtype const t, const_N_Vector x, gsl::span<realtype> root
8585
) {
8686
auto const x_pos = compute_x_pos(x);
87-
// TODO too costly? only dynamic expressions?
87+
// TODO(performance) only dynamic expressions? consider storing a flag
88+
// for whether froot actually depends on `w`.
8889
fw(t, N_VGetArrayPointerConst(x_pos));
8990
std::ranges::fill(root, 0.0);
9091
froot(

0 commit comments

Comments
 (0)