Skip to content

Commit 304f039

Browse files
committed
tidying - add docstrings - rm outputs in notebook
1 parent 9d107b3 commit 304f039

File tree

8 files changed

+69
-845
lines changed

8 files changed

+69
-845
lines changed

doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Lines changed: 48 additions & 805 deletions
Large diffs are not rendered by default.

python/sdist/amici/_symbolic/de_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2720,6 +2720,7 @@ def _args_containing_t(self, expr):
27202720
def has_implicit_event_assignments(self) -> bool:
27212721
"""
27222722
Checks whether the model has event assignments with implicit triggers
2723+
(i.e. triggers that are not time based).
27232724
27242725
:return:
27252726
boolean indicating if event assignments with implicit triggers are present

python/sdist/amici/_symbolic/de_model_components.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,9 @@ def get_trigger_times(self) -> set[sp.Expr]:
877877
return set(self._t_root)
878878

879879
def _implicit_symbols(self):
880+
"""Get implicit symbols in the event trigger function.
881+
That is, all symbols except time and petab indicator variables.
882+
"""
880883
symbols = [str(s) for s in list(self.get_val().free_symbols)]
881884
implicit_symbols = []
882885
for s in symbols:

python/sdist/amici/importers/petab/v1/parameter_mapping.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,11 +474,9 @@ def create_parameter_mapping_for_condition(
474474
# ExpData.x0, but in the case of pre-equilibration this would not allow for
475475
# resetting initial states.
476476

477-
states_in_condition_table = get_states_in_condition_table(
477+
if states_in_condition_table := get_states_in_condition_table(
478478
petab_problem, condition
479-
)
480-
481-
if states_in_condition_table:
479+
):
482480
# set indicator fixed parameter for preeq
483481
# (we expect here, that this parameter was added during import and
484482
# that it was not added by the user with a different meaning...)
@@ -527,7 +525,7 @@ def create_parameter_mapping_for_condition(
527525
value,
528526
fill_fixed_parameters=fill_fixed_parameters,
529527
)
530-
# set dummy value as above
528+
# set dummy value as above
531529
if condition_map_preeq:
532530
condition_map_preeq[init_par_id] = 0.0
533531
condition_scale_map_preeq[init_par_id] = LIN

python/sdist/amici/jax/model.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -615,15 +615,6 @@ def simulate_condition_unjitted(
615615
if not h_mask.shape[0]:
616616
h_mask = jnp.ones(self.n_events, dtype=jnp.bool_)
617617

618-
if os.getenv("JAX_DEBUG") == "1":
619-
jax.debug.print(
620-
"x_reinit: {}, x_preeq: {}, x_def: {}. p: {}",
621-
x_reinit,
622-
x_preeq,
623-
self._x0(t0, p),
624-
p,
625-
)
626-
627618
if x_preeq.shape[0]:
628619
x = x_preeq
629620
elif init_override.shape[0]:
@@ -913,6 +904,9 @@ def preequilibrate_condition(
913904
re-initialized state vector. If not provided, the state vector is not re-initialized.
914905
:param mask_reinit:
915906
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
907+
:param h_mask:
908+
mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it
909+
it marked as 1.0.
916910
:param solver:
917911
ODE solver
918912
:param controller:
@@ -1025,15 +1019,13 @@ def _handle_t0_event(
10251019

10261020
if os.getenv("JAX_DEBUG") == "1":
10271021
jax.debug.print(
1028-
"handle_t0_event h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}, y0_next: {}, y0: {}",
1022+
"h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}",
10291023
h,
10301024
rf0,
10311025
rfx,
10321026
roots_found,
10331027
roots_dir,
10341028
h_next,
1035-
y0_next,
1036-
y0,
10371029
)
10381030

10391031
return y0_next, t0_next, h_next, stats

python/sdist/amici/jax/petab.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,8 @@ def run_simulation(
14911491
Mask for states that need reinitialisation
14921492
:param x_reinit:
14931493
Reinitialisation values for states
1494+
:param h_mask:
1495+
Mask for the events that are part of the current experiment
14941496
:param solver:
14951497
ODE solver to use for simulation
14961498
:param controller:
@@ -1503,6 +1505,8 @@ def run_simulation(
15031505
:param x_preeq:
15041506
Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will
15051507
be initialised to the model default values.
1508+
:param h_preeq:
1509+
Pre-equilibration event mask. Can be empty if no pre-equilibration is available
15061510
:param ts_mask:
15071511
padding mask, see :meth:`JAXModel.simulate_condition` for details.
15081512
:param t_zeros:
@@ -1563,6 +1567,9 @@ def run_simulations(
15631567
:param preeq_array:
15641568
Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation
15651569
conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty.
1570+
:param h_preeqs:
1571+
Matrix of pre-equilibration event heaviside variables indicating whether an event condition is false or
1572+
true after preequilibration.
15661573
:param solver:
15671574
ODE solver to use for simulation.
15681575
:param controller:
@@ -1678,6 +1685,8 @@ def run_preequilibration(
16781685
Mask for states that need reinitialisation
16791686
:param x_reinit:
16801687
Reinitialisation values for states
1688+
:param h_mask:
1689+
Mask for the events that are part of the current experiment
16811690
:param solver:
16821691
ODE solver to use for simulation
16831692
:param controller:

python/sdist/amici/sim/jax/__init__.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,27 +117,6 @@ def _conditions_to_experiment_map(experiment_df: pd.DataFrame) -> dict[str, str]
117117
}
118118
return condition_to_experiment
119119

120-
# def get_states_in_condition_table_v2(
121-
# petab_problem,
122-
# condition: dict | pd.Series = None,
123-
# ) -> dict[str, tuple[float | str | None, float | str | None]]:
124-
# """Get states and their initial condition as specified in the condition table.
125-
126-
# Returns: Dictionary: ``stateId -> (initial condition simulation)``
127-
# """
128-
# states = {
129-
# target_id: (target_value, None)
130-
# if condition_id == condition[petabv1.SIMULATION_CONDITION_ID]
131-
# else (None, None)
132-
# for condition_id, target_id, target_value in zip(
133-
# petab_problem.condition_df[petabv2.C.CONDITION_ID],
134-
# petab_problem.condition_df[petabv2.C.TARGET_ID],
135-
# petab_problem.condition_df[petabv2.C.TARGET_VALUE],
136-
# )
137-
# }
138-
139-
# return states
140-
141120
def _try_float(value):
142121
try:
143122
return float(value)

tests/petab_test_suite/test_petab_v2_suite.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,7 @@ def run():
245245
n_total = 0
246246
version = "v2.0.0"
247247

248-
# for jax in (False, True):
249-
for jax in (True):
248+
for jax in (False, True):
250249
cases = list(petabtests.get_cases("sbml", version=version))
251250
n_total += len(cases)
252251
for case in cases:

0 commit comments

Comments
 (0)