Skip to content

Commit f44944d

Browse files
committed
rework petabv2 jax test cases with ExperimentsToSbmlEvents and no v1 parameter_mapping
1 parent 788c3b5 commit f44944d

File tree

8 files changed

+257
-389
lines changed

8 files changed

+257
-389
lines changed

python/sdist/amici/importers/petab/v1/parameter_mapping.py

Lines changed: 23 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import numpy as np
2727
import pandas as pd
2828
import petab.v1 as petab
29-
import petab.v2 as petabv2
3029
import sympy as sp
3130
from petab.v1.C import * # noqa: F403
3231
from petab.v1.C import (
@@ -42,9 +41,6 @@
4241

4342
from amici.importers.sbml import get_species_initial
4443
from amici.sim.sundials import AmiciModel, ParameterScaling
45-
from amici.sim.jax import (
46-
_set_initial_state_v2, get_states_in_condition_table_v2, reformat_for_v2, fixup_v2_parameter_mapping
47-
)
4844

4945
from . import PREEQ_INDICATOR_ID
5046
from .util import get_states_in_condition_table
@@ -384,15 +380,10 @@ def create_parameter_mapping(
384380
else petab_problem.mapping_df
385381
)
386382

387-
# Do some reformatting if V2 problem
388-
measurement_df, condition_df = reformat_for_v2(petab_problem) if isinstance(
389-
petab_problem, petabv2.Problem
390-
) else (petab_problem.measurement_df, petab_problem.condition_df)
391-
392383
prelim_parameter_mapping = (
393384
petab.get_optimization_to_simulation_parameter_mapping(
394-
condition_df=condition_df,
395-
measurement_df=measurement_df,
385+
condition_df=petab_problem.condition_df,
386+
measurement_df=petab_problem.measurement_df,
396387
parameter_df=petab_problem.parameter_df,
397388
observable_df=petab_problem.observable_df,
398389
mapping_df=mapping,
@@ -409,11 +400,6 @@ def create_parameter_mapping(
409400
for (_, condition), prelim_mapping_for_condition in zip(
410401
simulation_conditions.iterrows(), prelim_parameter_mapping, strict=True
411402
):
412-
if isinstance(petab_problem, petabv2.Problem):
413-
prelim_mapping_for_condition = fixup_v2_parameter_mapping(
414-
prelim_mapping_for_condition, petab_problem
415-
)
416-
417403
mapping_for_condition = create_parameter_mapping_for_condition(
418404
prelim_mapping_for_condition,
419405
condition,
@@ -488,14 +474,9 @@ def create_parameter_mapping_for_condition(
488474
# ExpData.x0, but in the case of pre-equilibration this would not allow for
489475
# resetting initial states.
490476

491-
if isinstance(petab_problem, petabv2.Problem):
492-
states_in_condition_table = get_states_in_condition_table_v2(
493-
petab_problem, condition
494-
)
495-
else:
496-
states_in_condition_table = get_states_in_condition_table(
497-
petab_problem, condition
498-
)
477+
states_in_condition_table = get_states_in_condition_table(
478+
petab_problem, condition
479+
)
499480

500481
if states_in_condition_table:
501482
# set indicator fixed parameter for preeq
@@ -536,26 +517,16 @@ def create_parameter_mapping_for_condition(
536517
# for simulation
537518
condition_id = condition[SIMULATION_CONDITION_ID]
538519
init_par_id = f"initial_{element_id}_sim"
539-
if isinstance(petab_problem, petabv2.Problem):
540-
_set_initial_state_v2(
541-
petab_problem,
542-
init_par_id,
543-
condition_map_sim,
544-
condition_scale_map_sim,
545-
value,
546-
fill_fixed_parameters=fill_fixed_parameters,
547-
)
548-
else:
549-
_set_initial_state(
550-
petab_problem,
551-
condition_id,
552-
element_id,
553-
init_par_id,
554-
condition_map_sim,
555-
condition_scale_map_sim,
556-
value,
557-
fill_fixed_parameters=fill_fixed_parameters,
558-
)
520+
_set_initial_state(
521+
petab_problem,
522+
condition_id,
523+
element_id,
524+
init_par_id,
525+
condition_map_sim,
526+
condition_scale_map_sim,
527+
value,
528+
fill_fixed_parameters=fill_fixed_parameters,
529+
)
559530
# set dummy value as above
560531
if condition_map_preeq:
561532
condition_map_preeq[init_par_id] = 0.0
@@ -609,17 +580,14 @@ def create_parameter_mapping_for_condition(
609580
)
610581
logger.debug(f"Variable parameters simulation: {condition_map_sim_var}")
611582

612-
if isinstance(petab_problem, petabv2.Problem):
613-
pass
614-
else:
615-
petab.merge_preeq_and_sim_pars_condition(
616-
condition_map_preeq_var,
617-
condition_map_sim_var,
618-
condition_scale_map_preeq_var,
619-
condition_scale_map_sim_var,
620-
condition,
621-
)
622-
logger.debug(f"Merged: {condition_map_sim_var}")
583+
petab.merge_preeq_and_sim_pars_condition(
584+
condition_map_preeq_var,
585+
condition_map_sim_var,
586+
condition_scale_map_preeq_var,
587+
condition_scale_map_sim_var,
588+
condition,
589+
)
590+
logger.debug(f"Merged: {condition_map_sim_var}")
623591

624592
if "sciml" in petab_problem.extensions_config:
625593
hybridizations = [

python/sdist/amici/importers/petab/v1/petab_import.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313

1414
import pandas as pd
1515
import petab.v1 as petab
16-
import petab.v2 as petabv2
1716
from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML
1817

1918
import amici
2019
from amici.logging import get_logger
21-
from amici.sim.jax import reformat_petab_v2_to_v1, add_events_to_sbml
2220

2321
from .import_helpers import (
2422
_can_import_model,
@@ -88,10 +86,6 @@ def import_petab_problem(
8886
:return:
8987
The imported model (if ``jax=False``) or JAX problem (if ``jax=True``).
9088
"""
91-
if isinstance(petab_problem, petabv2.Problem):
92-
petab_problem = add_events_to_sbml(petab_problem)
93-
petab_problem = reformat_petab_v2_to_v1(petab_problem)
94-
9589
if petab_problem.model.type_id not in (MODEL_TYPE_SBML, MODEL_TYPE_PYSB):
9690
raise NotImplementedError(
9791
"Unsupported model type " + petab_problem.model.type_id

python/sdist/amici/jax/_simulation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def body_fn(carry):
174174

175175
def solve(
176176
p: jt.Float[jt.Array, "np"],
177+
t0: jnp.float_,
177178
ts: jt.Float[jt.Array, "nt_dyn"],
178179
tcl: jt.Float[jt.Array, "ncl"],
179180
h: jt.Float[jt.Array, "ne"],
@@ -195,6 +196,8 @@ def solve(
195196
196197
:param p:
197198
parameters
199+
:param t0:
200+
initial time point
198201
:param ts:
199202
time points at which solutions are evaluated
200203
:param tcl:
@@ -226,7 +229,7 @@ def solve(
226229
if not root_cond_fns:
227230
# no events, we can just run a single segment
228231
sol, _, stats = _run_segment(
229-
0.0,
232+
t0,
230233
ts[-1],
231234
x0,
232235
p,
@@ -319,7 +322,7 @@ def body_fn(carry):
319322
body_fn,
320323
(
321324
jnp.zeros((ts.shape[0], x0.shape[0]), dtype=x0.dtype) + x0,
322-
0.0,
325+
t0,
323326
x0,
324327
jnp.zeros((ts.shape[0], h.shape[0]), dtype=h.dtype),
325328
h,

python/sdist/amici/jax/model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,17 +599,27 @@ def simulate_condition_unjitted(
599599
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
600600
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
601601
h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]),
602+
t_zero: jnp.float_ = 0.0,
602603
ret: ReturnValue = ReturnValue.llh,
603604
) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
604605
"""
605606
Unjitted version of simulate_condition.
606607
607608
See :meth:`simulate_condition` for full documentation.
608609
"""
609-
t0 = 0.0
610+
t0 = t_zero
610611
if p is None:
611612
p = self.parameters
612613

614+
if os.getenv("JAX_DEBUG") == "1":
615+
jax.debug.print(
616+
"x_reinit: {}, x_preeq: {}, x_def: {}. p: {}",
617+
x_reinit,
618+
x_preeq,
619+
self._x0(t0, p),
620+
p,
621+
)
622+
613623
if x_preeq.shape[0]:
614624
x = x_preeq
615625
elif init_override.shape[0]:
@@ -626,6 +636,7 @@ def simulate_condition_unjitted(
626636
# Re-initialization
627637
if x_reinit.shape[0]:
628638
x = jnp.where(mask_reinit, x_reinit, x)
639+
629640
x_solver = self._x_solver(x)
630641
tcl = self._tcl(x, p)
631642

@@ -645,6 +656,7 @@ def simulate_condition_unjitted(
645656
if ts_dyn.shape[0]:
646657
x_dyn, h_dyn, stats_dyn = solve(
647658
p,
659+
t0,
648660
ts_dyn,
649661
tcl,
650662
h,
@@ -781,6 +793,7 @@ def simulate_condition(
781793
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
782794
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
783795
h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]),
796+
t_zero: jnp.float_ = 0.0,
784797
ret: ReturnValue = ReturnValue.llh,
785798
) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
786799
r"""
@@ -863,6 +876,7 @@ def simulate_condition(
863876
init_override_mask,
864877
ts_mask,
865878
h_mask,
879+
t_zero,
866880
ret,
867881
)
868882

@@ -929,6 +943,7 @@ def preequilibrate_condition(
929943
tcl,
930944
h,
931945
current_x,
946+
h_mask,
932947
solver,
933948
controller,
934949
root_finder,
@@ -992,7 +1007,7 @@ def _handle_t0_event(
9921007

9931008
if os.getenv("JAX_DEBUG") == "1":
9941009
jax.debug.print(
995-
"h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}, y0_next: {}, y0: {}",
1010+
"handle_t0_event h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}, y0_next: {}, y0: {}",
9961011
h,
9971012
rf0,
9981013
rfx,

python/sdist/amici/jax/ode_export.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,20 @@ def __init__(
150150
"The JAX backend does not support models with algebraic states."
151151
)
152152

153-
if not ode_model.has_only_time_dependent_event_assignments():
154-
raise NotImplementedError(
155-
"The JAX backend does not support event assignments with explicit non-time dependent triggers."
156-
)
153+
# if not ode_model.has_only_time_dependent_event_assignments():
154+
# raise NotImplementedError(
155+
# "The JAX backend does not support event assignments with explicit non-time dependent triggers."
156+
# )
157157

158158
if ode_model.has_priority_events():
159159
raise NotImplementedError(
160160
"The JAX backend does not support event priorities."
161161
)
162162

163-
if ode_model.has_implicit_event_assignments():
164-
raise NotImplementedError(
165-
"The JAX backend does not support event assignments with implicit triggers."
166-
)
163+
# if ode_model.has_implicit_event_assignments():
164+
# raise NotImplementedError(
165+
# "The JAX backend does not support event assignments with implicit triggers."
166+
# )
167167

168168
self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG
169169

0 commit comments

Comments
 (0)