Skip to content

Commit cadaa43

Browse files
committed
implement implicit triggers using fixed parameters check
1 parent 97f69f4 commit cadaa43

File tree

6 files changed

+32
-25
lines changed

6 files changed

+32
-25
lines changed

python/sdist/amici/_symbolic/de_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2684,7 +2684,9 @@ def has_implicit_event_assignments(self) -> bool:
26842684
:return:
26852685
boolean indicating if event assignments with implicit triggers are present
26862686
"""
2687-
return any(event.updates_state and event._implicit_symbols() for event in self._events)
2687+
fixed_symbols = set([k._symbol for k in self._fixed_parameters])
2688+
allowed_symbols = fixed_symbols | {amici_time_symbol}
2689+
return any(event.updates_state and not event.has_explicit_trigger_times(allowed_symbols) for event in self._events)
26882690

26892691
def toposort_expressions(
26902692
self, reorder: bool = True

python/sdist/amici/_symbolic/de_model_components.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -863,8 +863,12 @@ def has_explicit_trigger_times(
863863
"""
864864
if allowed_symbols is None:
865865
return len(self._t_root) > 0
866+
867+
if len(self._t_root) == 0:
868+
t = self.get_val()
869+
return t.is_Number or t.free_symbols.issubset(allowed_symbols)
866870

867-
return len(self._t_root) > 0 and all(
871+
return all(
868872
t.is_Number or t.free_symbols.issubset(allowed_symbols)
869873
for t in self._t_root
870874
)
@@ -883,18 +887,6 @@ def get_trigger_times(self) -> set[sp.Expr]:
883887
time points at which the event triggers.
884888
"""
885889
return set(self._t_root)
886-
887-
def _implicit_symbols(self):
888-
"""Get implicit symbols in the event trigger function.
889-
That is, all symbols except time and petab indicator variables.
890-
"""
891-
symbols = [str(s) for s in list(self.get_val().free_symbols)]
892-
implicit_symbols = []
893-
for s in symbols:
894-
if (s.startswith("_petab_") and "indicator" in s) or s == "t":
895-
continue
896-
implicit_symbols.append(s)
897-
return len(implicit_symbols) > 0
898890

899891
@property
900892
def uses_values_from_trigger_time(self) -> bool:

python/sdist/amici/importers/petab/_petab_importer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def _do_import_sbml(self):
353353
model_name=self._module_name,
354354
output_dir=self.output_dir,
355355
observation_model=observation_model,
356+
fixed_parameters=fixed_parameters,
356357
verbose=self._verbose,
357358
# **kwargs,
358359
)

python/sdist/amici/importers/sbml/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def sbml2jax(
438438
self,
439439
model_name: str,
440440
output_dir: str | Path = None,
441+
fixed_parameters: Iterable[str] = None,
441442
observation_model: list[MeasurementChannel] = None,
442443
verbose: int | bool = logging.ERROR,
443444
compute_conservation_laws: bool = True,
@@ -465,6 +466,9 @@ def sbml2jax(
465466
:param output_dir:
466467
Directory where the generated model package will be stored.
467468
469+
:param fixed_parameters:
470+
SBML Ids to be excluded from sensitivity analysis
471+
468472
:param observation_model:
469473
The different measurement channels that make up the observation
470474
model, see :class:`amici.importers.utils.MeasurementChannel`.
@@ -513,6 +517,7 @@ def sbml2jax(
513517
set_log_level(logger, verbose)
514518

515519
ode_model = self._build_ode_model(
520+
fixed_parameters=fixed_parameters,
516521
observation_model=observation_model,
517522
verbose=verbose,
518523
compute_conservation_laws=compute_conservation_laws,

python/sdist/amici/jax/jax.template.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
2121
def __init__(self):
2222
self.jax_py_file = Path(__file__).resolve()
2323
self.nns = {TPL_NETS}
24-
self.parameters = TPL_P_VALUES
24+
self.parameters = TPL_ALL_P_VALUES
2525
super().__init__()
2626

2727
def _xdot(self, t, x, args):
2828
p, tcl, h = args
2929

3030
TPL_X_SYMS = x
31-
TPL_P_SYMS = p
31+
TPL_ALL_P_SYMS = p
3232
TPL_TCL_SYMS = tcl
3333
TPL_H_SYMS = h
3434
TPL_W_SYMS = self._w(t, x, p, tcl, h)
@@ -39,7 +39,7 @@ def _xdot(self, t, x, args):
3939

4040
def _w(self, t, x, p, tcl, h):
4141
TPL_X_SYMS = x
42-
TPL_P_SYMS = p
42+
TPL_ALL_P_SYMS = p
4343
TPL_TCL_SYMS = tcl
4444
TPL_H_SYMS = h
4545

@@ -48,7 +48,7 @@ def _w(self, t, x, p, tcl, h):
4848
return TPL_W_RET
4949

5050
def _x0(self, t, p):
51-
TPL_P_SYMS = p
51+
TPL_ALL_P_SYMS = p
5252

5353
TPL_X0_EQ
5454

@@ -71,15 +71,15 @@ def _x_rdata(self, x, tcl):
7171

7272
def _tcl(self, x, p):
7373
TPL_X_RDATA_SYMS = x
74-
TPL_P_SYMS = p
74+
TPL_ALL_P_SYMS = p
7575

7676
TPL_TOTAL_CL_EQ
7777

7878
return TPL_TOTAL_CL_RET
7979

8080
def _y(self, t, x, p, tcl, h, op):
8181
TPL_X_SYMS = x
82-
TPL_P_SYMS = p
82+
TPL_ALL_P_SYMS = p
8383
TPL_W_SYMS = self._w(t, x, p, tcl, h)
8484
TPL_OP_SYMS = op
8585

@@ -88,7 +88,7 @@ def _y(self, t, x, p, tcl, h, op):
8888
return TPL_Y_RET
8989

9090
def _sigmay(self, y, p, np):
91-
TPL_P_SYMS = p
91+
TPL_ALL_P_SYMS = p
9292

9393
TPL_Y_SYMS = y
9494
TPL_NP_SYMS = np
@@ -110,15 +110,15 @@ def _nllh(self, t, x, p, tcl, h, my, iy, op, np):
110110
return TPL_JY_RET.at[iy].get()
111111

112112
def _known_discs(self, p):
113-
TPL_P_SYMS = p
113+
TPL_ALL_P_SYMS = p
114114

115115
return TPL_ROOTS
116116

117117
def _root_cond_fn(self, t, y, args, **_):
118118
p, tcl, h = args
119119

120120
TPL_X_SYMS = y
121-
TPL_P_SYMS = p
121+
TPL_ALL_P_SYMS = p
122122
TPL_TCL_SYMS = tcl
123123
TPL_H_SYMS = h
124124
TPL_W_SYMS = self._w(t, y, p, tcl, h)
@@ -130,7 +130,7 @@ def _root_cond_fn(self, t, y, args, **_):
130130

131131
def _delta_x(self, y, p, tcl):
132132
TPL_X_SYMS = y
133-
TPL_P_SYMS = p
133+
TPL_ALL_P_SYMS = p
134134
TPL_TCL_SYMS = tcl
135135
# FIXME: workaround until state from event time is properly passed
136136
TPL_X_OLD_SYMS = y
@@ -157,7 +157,7 @@ def state_ids(self):
157157

158158
@property
159159
def parameter_ids(self):
160-
return TPL_P_IDS
160+
return TPL_ALL_P_IDS
161161

162162
@property
163163
def expression_ids(self):

python/sdist/amici/jax/ode_export.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def _generate_jax_code(self) -> None:
213213
)
214214
sym_names = (
215215
"p",
216+
"k",
216217
"np",
217218
"op",
218219
"x",
@@ -261,6 +262,9 @@ def _generate_jax_code(self) -> None:
261262
# tuple of variable names (ids as they are unique)
262263
**_jax_variable_ids(self.model, ("p", "k", "y", "w", "x_rdata")),
263264
"P_VALUES": _jnp_array_str(self.model.val("p")),
265+
"ALL_P_VALUES": _jnp_array_str(self.model.val("p") + self.model.val("k")),
266+
"ALL_P_IDS": "".join(f'"{s.name}", ' for s in self._get_all_p_syms()),
267+
"ALL_P_SYMS": "".join(f"{s.name}, " for s in self._get_all_p_syms()),
264268
"ROOTS": _jnp_array_str(
265269
{
266270
_print_trigger_root(root)
@@ -297,6 +301,9 @@ def _generate_jax_code(self) -> None:
297301
tpl_data,
298302
)
299303

304+
def _get_all_p_syms(self) -> list[sp.Symbol]:
305+
return list(self.model.sym("p")) + list(self.model.sym("k"))
306+
300307
def _generate_nn_code(self) -> None:
301308
for net_name, net in self.hybridization.items():
302309
generate_equinox(

0 commit comments

Comments
 (0)