diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index fc027d20c8..1ad19c8f52 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -1595,6 +1595,11 @@ def _compute_equation(self, name: str) -> None: self._eqs[name] = event_eqs + elif name == "x_old": + self._eqs[name] = sp.Matrix( + [state.get_x_rdata() for state in self.states()] + ) + elif name == "z": event_observables = [ sp.zeros(self.num_eventobs(), 1) for _ in self._events diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index fd1a4cb45e..5b60d10c89 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -24,6 +24,7 @@ from amici._symbolic import DEModel, Event from amici.importers.utils import MeasurementChannel, amici_time_symbol from amici.logging import get_logger +from amici.jax.petab import JAXProblem from .v1.sbml_import import _add_global_parameter @@ -151,10 +152,6 @@ def __init__( "PEtab v2 importer currently only supports SBML and PySB " f"models. Got {self.petab_problem.model.type_id!r}." ) - if jax: - raise NotImplementedError( - "PEtab v2 importer currently does not support JAX. " - ) if self._debug: print("PetabImpoter.__init__: petab_problem:") @@ -577,6 +574,11 @@ def import_module(self, force_import: bool = False) -> amici.ModelModule: else: self._do_import_pysb() + if self._jax: + return amici.import_model_module( + Path(self.output_dir).stem, Path(self.output_dir).parent + ) + return amici.import_model_module( self._module_name, self.output_dir, @@ -601,6 +603,11 @@ def create_simulator( """ from amici.sim.sundials.petab import ExperimentManager, PetabSimulator + if self._jax: + model_module = self.import_module(force_import=force_import) + model = model_module.Model() + return JAXProblem(model, self.petab_problem) + model = self.import_module(force_import=force_import).get_model() em = ExperimentManager(model=model, petab_problem=self.petab_problem) return PetabSimulator(em=em) diff --git a/python/sdist/amici/importers/petab/v1/parameter_mapping.py b/python/sdist/amici/importers/petab/v1/parameter_mapping.py index b2b7837e7b..f20bebe727 100644 --- a/python/sdist/amici/importers/petab/v1/parameter_mapping.py +++ b/python/sdist/amici/importers/petab/v1/parameter_mapping.py @@ -355,7 +355,7 @@ def create_parameter_mapping( converter_config = ( libsbml.SBMLLocalParameterConverter().getDefaultProperties() ) - petab_problem.sbml_document.convert(converter_config) + petab_problem.model.sbml_document.convert(converter_config) else: logger.debug( "No petab_problem.sbml_document is set. Cannot " @@ -474,9 +474,11 @@ def create_parameter_mapping_for_condition( # ExpData.x0, but in the case of pre-equilibration this would not allow for # resetting initial states. - if states_in_condition_table := get_states_in_condition_table( + states_in_condition_table = get_states_in_condition_table( petab_problem, condition - ): + ) + + if states_in_condition_table: # set indicator fixed parameter for preeq # (we expect here, that this parameter was added during import and # that it was not added by the user with a different meaning...) 
@@ -525,7 +527,7 @@ def create_parameter_mapping_for_condition( value, fill_fixed_parameters=fill_fixed_parameters, ) - # set dummy value as above + # set dummy value as above if condition_map_preeq: condition_map_preeq[init_par_id] = 0.0 condition_scale_map_preeq[init_par_id] = LIN diff --git a/python/sdist/amici/importers/petab/v1/sbml_import.py b/python/sdist/amici/importers/petab/v1/sbml_import.py index f1953f9592..fd0a499099 100644 --- a/python/sdist/amici/importers/petab/v1/sbml_import.py +++ b/python/sdist/amici/importers/petab/v1/sbml_import.py @@ -3,10 +3,12 @@ import re from _collections import OrderedDict from itertools import chain +import pandas as pd from pathlib import Path import libsbml import petab.v1 as petab +import petab.v2 as petabv2 import sympy as sp from petab.v1.models import MODEL_TYPE_SBML from sympy.abc import _clash @@ -304,7 +306,10 @@ def import_model_sbml( if validate: logger.info("Validating PEtab problem ...") - petab.lint_problem(petab_problem) + if isinstance(petab_problem, petabv2.Problem): + petabv2.lint_problem(petab_problem) + else: + petab.lint_problem(petab_problem) # Model name from SBML ID or filename if model_name is None: diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 7b19e61517..c4376a24af 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -26,6 +26,7 @@ def eq( tcl: jt.Float[jt.Array, "ncl"], h0: jt.Float[jt.Array, "ne"], x0: jt.Float[jt.Array, "nxs"], + h_mask: jt.Bool[jt.Array, "ne"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -147,6 +148,7 @@ def body_fn(carry): term, root_cond_fn, delta_x, + h_mask, stats, ) @@ -172,10 +174,12 @@ def body_fn(carry): def solve( p: jt.Float[jt.Array, "np"], + t0: jnp.float_, ts: jt.Float[jt.Array, "nt_dyn"], tcl: jt.Float[jt.Array, "ncl"], h: jt.Float[jt.Array, "ne"], x0: jt.Float[jt.Array, "nxs"], + h_mask: jt.Bool[jt.Array, "ne"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -192,6 +196,8 @@ def solve( :param p: parameters + :param t0: + initial time point :param ts: time points at which solutions are evaluated :param tcl: @@ -223,7 +229,7 @@ def solve( if not root_cond_fns: # no events, we can just run a single segment sol, _, stats = _run_segment( - 0.0, + t0, ts[-1], x0, p, @@ -301,6 +307,7 @@ def body_fn(carry): term, root_cond_fn, delta_x, + h_mask, stats, ) @@ -315,7 +322,7 @@ def body_fn(carry): body_fn, ( jnp.zeros((ts.shape[0], x0.shape[0]), dtype=x0.dtype) + x0, - 0.0, + t0, x0, jnp.zeros((ts.shape[0], h.shape[0]), dtype=h.dtype), h, @@ -419,6 +426,7 @@ def _handle_event( term: diffrax.ODETerm, root_cond_fn: Callable, delta_x: Callable, + h_mask: jt.Bool[jt.Array, "ne"], stats: dict, ): args = (p, tcl, h) @@ -446,6 +454,8 @@ def _handle_event( delta_x, ) + h_next = jnp.where(h_mask, h_next, h) + if os.getenv("JAX_DEBUG") == "1": jax.debug.print( "rootvals: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}", diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 7a1636c651..785d5a06fb 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -598,6 +598,8 @@ def simulate_condition_unjitted( init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), + h_mask: jt.Bool[jt.Array, "ne"] = 
jnp.array([]), + t_zero: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "*nt"], dict]: """ @@ -605,10 +607,19 @@ def simulate_condition_unjitted( See :meth:`simulate_condition` for full documentation. """ - t0 = 0.0 + t0 = t_zero if p is None: p = self.parameters + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "x_reinit: {}, x_preeq: {}, x_def: {}. p: {}", + x_reinit, + x_preeq, + self._x0(t0, p), + p, + ) + if x_preeq.shape[0]: x = x_preeq elif init_override.shape[0]: @@ -625,6 +636,7 @@ def simulate_condition_unjitted( # Re-initialization if x_reinit.shape[0]: x = jnp.where(mask_reinit, x_reinit, x) + x_solver = self._x_solver(x) tcl = self._tcl(x, p) @@ -636,6 +648,7 @@ def simulate_condition_unjitted( root_finder, self._root_cond_fn, self._delta_x, + h_mask, {}, ) @@ -643,10 +656,12 @@ def simulate_condition_unjitted( if ts_dyn.shape[0]: x_dyn, h_dyn, stats_dyn = solve( p, + t0, ts_dyn, tcl, h, x_solver, + h_mask, solver, controller, root_finder, @@ -671,6 +686,7 @@ def simulate_condition_unjitted( tcl, h, x_solver, + h_mask, solver, controller, root_finder, @@ -776,6 +792,8 @@ def simulate_condition( init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), + h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]), + t_zero: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "*nt"], dict]: r""" @@ -828,6 +846,9 @@ def simulate_condition( :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. + :param h_mask: + mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it + it marked as 1.0. :param ret: which output to return. See :class:`ReturnValue` for available options. 
:return: @@ -854,6 +875,8 @@ def simulate_condition( init_override, init_override_mask, ts_mask, + h_mask, + t_zero, ret, ) @@ -863,6 +886,7 @@ def preequilibrate_condition( p: jt.Float[jt.Array, "np"] | None, x_reinit: jt.Float[jt.Array, "*nx"], mask_reinit: jt.Bool[jt.Array, "*nx"], + h_mask: jt.Bool[jt.Array, "ne"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -910,6 +934,7 @@ def preequilibrate_condition( root_finder, self._root_cond_fn, self._delta_x, + h_mask, {}, ) @@ -918,6 +943,7 @@ def preequilibrate_condition( tcl, h, current_x, + h_mask, solver, controller, root_finder, @@ -941,10 +967,12 @@ def _handle_t0_event( root_finder: AbstractRootFinder, root_cond_fn: Callable, delta_x: Callable, + h_mask: jt.Bool[jt.Array, "ne"], stats: dict, ): + y0 = y0_next.copy() rf0 = self.event_initial_values - 0.5 - h = jnp.heaviside(rf0, 0.0) + h = jnp.where(h_mask, jnp.heaviside(rf0, 0.0), jnp.ones_like(rf0)) args = (p, tcl, h) rfx = root_cond_fn(t0_next, y0_next, args) roots_dir = jnp.sign(rfx - rf0) @@ -979,13 +1007,15 @@ def _handle_t0_event( if os.getenv("JAX_DEBUG") == "1": jax.debug.print( - "h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}", + "handle_t0_event h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}, y0_next: {}, y0: {}", h, rf0, rfx, roots_found, roots_dir, h_next, + y0_next, + y0, ) return y0_next, t0_next, h_next, stats diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 29b6458957..337b1e863c 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -149,16 +149,21 @@ def __init__( raise NotImplementedError( "The JAX backend does not support models with algebraic states." ) + + # if not ode_model.has_only_time_dependent_event_assignments(): + # raise NotImplementedError( + # "The JAX backend does not support event assignments with explicit non-time dependent triggers." + # ) if ode_model.has_priority_events(): raise NotImplementedError( "The JAX backend does not support event priorities." ) - if ode_model.has_implicit_event_assignments(): - raise NotImplementedError( - "The JAX backend does not support event assignments with implicit triggers." - ) + # if ode_model.has_implicit_event_assignments(): + # raise NotImplementedError( + # "The JAX backend does not support event assignments with implicit triggers." 
+ # ) self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 4ed3c2713a..3bc6e21a4d 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -8,6 +8,7 @@ from numbers import Number from pathlib import Path +import os import diffrax import equinox as eqx import h5py @@ -17,7 +18,8 @@ import numpy as np import optimistix import pandas as pd -import petab.v1 as petab +import petab.v1 as petabv1 +import petab.v2 as petabv2 from optimistix import AbstractRootFinder from amici import _module_from_path @@ -27,6 +29,9 @@ ) from amici.jax.model import JAXModel, ReturnValue from amici.logging import get_logger +from amici.sim.jax import ( + add_default_experiment_names_to_v2_problem, get_simulation_conditions_v2, _build_simulation_df_v2, _try_float +) DEFAULT_CONTROLLER_SETTINGS = { "atol": 1e-8, @@ -42,9 +47,9 @@ } SCALE_TO_INT = { - petab.LIN: 0, - petab.LOG: 1, - petab.LOG10: 2, + petabv1.LIN: 0, + petabv1.LOG: 1, + petabv1.LOG10: 2, } logger = get_logger(__name__, logging.WARNING) @@ -60,30 +65,43 @@ def jax_unscale( parameter: Parameter to be unscaled. scale_str: - One of ``petab.LIN``, ``petab.LOG``, ``petab.LOG10``. + One of ``petabv1.LIN``, ``petabv1.LOG``, ``petabv1.LOG10``. Returns: The unscaled parameter. """ - if scale_str == petab.LIN or not scale_str: + if scale_str == petabv1.LIN or not scale_str: return parameter - if scale_str == petab.LOG: + if scale_str == petabv1.LOG: return jnp.exp(parameter) - if scale_str == petab.LOG10: + if scale_str == petabv1.LOG10: return jnp.power(10, parameter) raise ValueError(f"Invalid parameter scaling: {scale_str}") # IDEA: Implement this class in petab-sciml instead? -class HybridProblem(petab.Problem): +class HybridProblem(petabv1.Problem): hybridization_df: pd.DataFrame - def __init__(self, petab_problem: petab.Problem): + def __init__(self, petab_problem: petabv1.Problem): + self.__dict__.update(petab_problem.__dict__) + self.hybridization_df = _get_hybridization_df(petab_problem) + +class HybridV2Problem(petabv2.Problem): + hybridization_df: pd.DataFrame + extensions_config: dict + + def __init__(self, petab_problem: petabv2.Problem): + if not hasattr(petab_problem, "extensions_config"): + self.extensions_config = {} self.__dict__.update(petab_problem.__dict__) self.hybridization_df = _get_hybridization_df(petab_problem) def _get_hybridization_df(petab_problem): + if not hasattr(petab_problem, "extensions_config"): + return None + if "sciml" in petab_problem.extensions_config: hybridizations = [ pd.read_csv(hf, sep="\t", index_col=0) @@ -95,7 +113,9 @@ def _get_hybridization_df(petab_problem): return hybridization_df -def _get_hybrid_petab_problem(petab_problem: petab.Problem): +def _get_hybrid_petab_problem(petab_problem: petabv1.Problem | petabv2.Problem): + if isinstance(petab_problem, petabv2.Problem): + return HybridV2Problem(petab_problem) return HybridProblem(petab_problem) @@ -132,9 +152,9 @@ class JAXProblem(eqx.Module): _np_mask: np.ndarray _np_indices: np.ndarray _petab_measurement_indices: np.ndarray - _petab_problem: petab.Problem | HybridProblem + _petab_problem: petabv1.Problem | HybridProblem | petabv2.Problem - def __init__(self, model: JAXModel, petab_problem: petab.Problem): + def __init__(self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Problem): """ Initialize a JAXProblem instance with a model and a PEtab problem. 
@@ -143,13 +163,21 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): :param petab_problem: PEtab problem to simulate. """ - scs = petab_problem.get_simulation_conditions_from_measurement_df() - self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) + if isinstance(petab_problem, petabv2.Problem): + petab_problem = add_default_experiment_names_to_v2_problem(petab_problem) + scs = get_simulation_conditions_v2(petab_problem) + self.simulation_conditions = scs.simulationConditionId + else: + scs = petab_problem.get_simulation_conditions_from_measurement_df() + self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) self._petab_problem = _get_hybrid_petab_problem(petab_problem) self.parameters, self.model = ( self._initialize_model_with_nominal_values(model) ) - self._parameter_mappings = self._get_parameter_mappings(scs) + if isinstance(petab_problem, petabv1.Problem): + self._parameter_mappings = self._get_parameter_mappings(scs) + else: + self._parameter_mappings = None ( self._ts_dyn, self._ts_posteq, @@ -197,7 +225,7 @@ def load(cls, directory: Path): :return: Loaded problem instance. """ - petab_problem = petab.Problem.from_yaml( + petab_problem = petabv1.Problem.from_yaml( directory / "problem.yaml", ) model = _module_from_path("jax", directory / "jax_py_file.py").Model() @@ -213,22 +241,22 @@ def _get_parameter_mappings( :param simulation_conditions: Simulation conditions to create parameter mappings for. Same format as returned by - :meth:`petab.Problem.get_simulation_conditions_from_measurement_df`. + :meth:`petabv1.Problem.get_simulation_conditions_from_measurement_df`. :return: Dictionary mapping simulation conditions to parameter mappings. """ - scs = list(set(simulation_conditions.values.flatten())) + scs = list(set(simulation_conditions.simulationConditionId)) petab_problem = copy.deepcopy(self._petab_problem) # remove observable and noise parameters from measurement dataframe as we are mapping them elsewhere petab_problem.measurement_df.drop( - columns=[petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS], + columns=[petabv1.OBSERVABLE_PARAMETERS, petabv1.NOISE_PARAMETERS], inplace=True, errors="ignore", ) mappings = create_parameter_mapping( petab_problem=petab_problem, simulation_conditions=[ - {petab.SIMULATION_CONDITION_ID: sc} for sc in scs + {petabv1.SIMULATION_CONDITION_ID: sc} for sc in scs ], scaled_parameters=False, allow_timepoint_specific_numeric_noise_parameters=True, @@ -262,7 +290,7 @@ def _get_measurements( :param simulation_conditions: Simulation conditions to create parameter mappings for. Same format as returned by - :meth:`petab.Problem.get_simulation_conditions_from_measurement_df`. + :meth:`petabv1.Problem.get_simulation_conditions_from_measurement_df`. 
:return: tuple of padded - dynamic time points @@ -283,7 +311,7 @@ def _get_measurements( petab_indices = dict() n_pars = dict() - for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: + for col in [petabv1.OBSERVABLE_PARAMETERS, petabv1.NOISE_PARAMETERS]: n_pars[col] = 0 if col in self._petab_problem.measurement_df: if np.issubdtype( @@ -295,7 +323,7 @@ def _get_measurements( else: n_pars[col] = ( self._petab_problem.measurement_df[col] - .str.split(petab.C.PARAMETER_SEPARATOR) + .str.split(petabv1.C.PARAMETER_SEPARATOR) .apply( lambda x: len(x) if isinstance(x, Sized) @@ -305,38 +333,54 @@ def _get_measurements( ) for _, simulation_condition in simulation_conditions.iterrows(): - query = " & ".join( - [f"{k} == '{v}'" for k, v in simulation_condition.items()] - ) + if "preequilibration" in simulation_condition[ + petabv1.SIMULATION_CONDITION_ID + ]: + continue + + if isinstance(self._petab_problem, HybridV2Problem): + query = " & ".join( + [ + f"{k} == '{v}'" + if isinstance(v, str) + else f"{k} == {v}" + for k, v in simulation_condition.items() + if k != petabv1.C.SIMULATION_CONDITION_ID + ] + ) + else: + query = " & ".join( + [f"{k} == '{v}'" for k, v in simulation_condition.items()] + ) m = self._petab_problem.measurement_df.query(query).sort_values( - by=petab.TIME + by=petabv1.TIME ) - ts = m[petab.TIME] + ts = m[petabv1.TIME] ts_dyn = ts[np.isfinite(ts)] ts_posteq = ts[np.logical_not(np.isfinite(ts))] index = pd.concat([ts_dyn, ts_posteq]).index ts_dyn = ts_dyn.values ts_posteq = ts_posteq.values - my = m[petab.MEASUREMENT].values + my = m[petabv1.MEASUREMENT].values iys = np.array( [ self.model.observable_ids.index(oid) - for oid in m[petab.OBSERVABLE_ID].values + for oid in m[petabv1.OBSERVABLE_ID].values ] ) if ( - petab.OBSERVABLE_TRANSFORMATION + petabv1.OBSERVABLE_TRANSFORMATION in self._petab_problem.observable_df ): iy_trafos = np.array( [ SCALE_TO_INT[ self._petab_problem.observable_df.loc[ - oid, petab.OBSERVABLE_TRANSFORMATION + oid, petabv1.OBSERVABLE_TRANSFORMATION ] ] - for oid in m[petab.OBSERVABLE_ID].values + for oid in m[petabv1.OBSERVABLE_ID].values ] ) else: @@ -350,16 +394,16 @@ def get_parameter_override(x): if ( x in self._petab_problem.parameter_df.index and not self._petab_problem.parameter_df.loc[ - x, petab.ESTIMATE + x, petabv1.ESTIMATE ] ): return self._petab_problem.parameter_df.loc[ - x, petab.NOMINAL_VALUE + x, petabv1.NOMINAL_VALUE ] return x - for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: - if col not in m or m[col].isna().all(): + for col in [petabv1.OBSERVABLE_PARAMETERS, petabv1.NOISE_PARAMETERS]: + if col not in m or m[col].isna().all() or all(m[col] == ''): mat_numeric = jnp.ones((len(m), n_pars[col])) par_mask = np.zeros_like(mat_numeric, dtype=bool) par_index = np.zeros_like(mat_numeric, dtype=int) @@ -368,7 +412,7 @@ def get_parameter_override(x): par_mask = np.zeros_like(mat_numeric, dtype=bool) par_index = np.zeros_like(mat_numeric, dtype=int) else: - split_vals = m[col].str.split(petab.C.PARAMETER_SEPARATOR) + split_vals = m[col].str.split(petabv1.C.PARAMETER_SEPARATOR) list_vals = split_vals.apply( lambda x: [get_parameter_override(y) for y in x] if isinstance(x, list) @@ -413,15 +457,15 @@ def get_parameter_override(x): iys, # 3 iy_trafos, # 4 parameter_overrides_numeric_vals[ - petab.OBSERVABLE_PARAMETERS + petabv1.OBSERVABLE_PARAMETERS ], # 5 - parameter_overrides_mask[petab.OBSERVABLE_PARAMETERS], # 6 + parameter_overrides_mask[petabv1.OBSERVABLE_PARAMETERS], # 6 parameter_overrides_par_indices[ - 
petab.OBSERVABLE_PARAMETERS + petabv1.OBSERVABLE_PARAMETERS ], # 7 - parameter_overrides_numeric_vals[petab.NOISE_PARAMETERS], # 8 - parameter_overrides_mask[petab.NOISE_PARAMETERS], # 9 - parameter_overrides_par_indices[petab.NOISE_PARAMETERS], # 10 + parameter_overrides_numeric_vals[petabv1.NOISE_PARAMETERS], # 8 + parameter_overrides_mask[petabv1.NOISE_PARAMETERS], # 9 + parameter_overrides_par_indices[petabv1.NOISE_PARAMETERS], # 10 ) petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) @@ -522,10 +566,16 @@ def pad_and_stack(output_index: int): ) def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: - simulation_conditions = ( - self._petab_problem.get_simulation_conditions_from_measurement_df() - ) - return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) + if isinstance(self._petab_problem, HybridV2Problem): + simulation_conditions = ( + get_simulation_conditions_v2(self._petab_problem) + ) + return tuple(tuple([row.simulationConditionId]) for _, row in simulation_conditions.iterrows()) + else: + simulation_conditions = ( + self._petab_problem.get_simulation_conditions_from_measurement_df() + ) + return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) def _initialize_model_parameters(self, model: JAXModel) -> dict: """ @@ -657,11 +707,11 @@ def _extract_nominal_values_from_petab( scalar = True # Determine value source (scalar from PEtab or array from file) - if np.isnan(row[petab.NOMINAL_VALUE]): + if np.isnan(row[petabv1.NOMINAL_VALUE]): value = par_arrays[net] scalar = False else: - value = float(row[petab.NOMINAL_VALUE]) + value = float(row[petabv1.NOMINAL_VALUE]) # Parse parameter name and set values to_set = self._parse_parameter_name(pname, model_pars) @@ -753,14 +803,14 @@ def _create_scaled_parameter_array(self) -> jt.Float[jt.Array, "np"]: """ return jnp.array( [ - petab.scale( + petabv1.scale( float( self._petab_problem.parameter_df.loc[ - pval, petab.NOMINAL_VALUE + pval, petabv1.NOMINAL_VALUE ] ), self._petab_problem.parameter_df.loc[ - pval, petab.PARAMETER_SCALE + pval, petabv1.PARAMETER_SCALE ], ) for pval in self.parameter_ids @@ -802,7 +852,19 @@ def _initialize_model_with_nominal_values( model = self._set_input_arrays(model, nn_input_arrays, model_pars) # Create scaled parameter array - parameter_array = self._create_scaled_parameter_array() + if isinstance(self._petab_problem, HybridV2Problem): + parameter_array = jnp.array( + [ + float( + self._petab_problem.parameter_df.loc[ + pval, petabv2.C.NOMINAL_VALUE + ] + ) + for pval in self.parameter_ids + ] + ) + else: + parameter_array = self._create_scaled_parameter_array() return parameter_array, model @@ -826,7 +888,7 @@ def _get_inputs(self) -> dict: .max(axis=0) + 1 ) - inputs[row["netId"]][row[petab.MODEL_ENTITY_ID]] = data_flat[ + inputs[row["netId"]][row[petabv1.MODEL_ENTITY_ID]] = data_flat[ "value" ].values.reshape(shape) return inputs @@ -839,11 +901,13 @@ def parameter_ids(self) -> list[str]: :return: PEtab parameter ids """ + if isinstance(self._petab_problem, HybridV2Problem): + return self._petab_problem.parameter_df[petabv2.C.ESTIMATE].index.tolist() return self._petab_problem.parameter_df[ - self._petab_problem.parameter_df[petab.ESTIMATE] + self._petab_problem.parameter_df[petabv1.ESTIMATE] == 1 & pd.to_numeric( - self._petab_problem.parameter_df[petab.NOMINAL_VALUE], + self._petab_problem.parameter_df[petabv1.NOMINAL_VALUE], errors="coerce", ).notna() ].index.tolist() @@ -858,8 +922,10 @@ def nn_output_ids(self) -> list[str]: """ if 
self._petab_problem.mapping_df is None: return [] + if self._petab_problem.mapping_df[petabv1.MODEL_ENTITY_ID].isnull().all(): + return [] return self._petab_problem.mapping_df[ - self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + self._petab_problem.mapping_df[petabv1.MODEL_ENTITY_ID] .str.split(".") .str[1] .str.startswith("output") @@ -895,7 +961,7 @@ def _unscale( def _eval_nn(self, output_par: str, condition_id: str): net_id = self._petab_problem.mapping_df.loc[ - output_par, petab.MODEL_ENTITY_ID + output_par, petabv1.MODEL_ENTITY_ID ].split(".")[0] nn = self.model.nns[net_id] @@ -905,12 +971,12 @@ def _is_net_input(model_id): model_id_map = ( self._petab_problem.mapping_df[ - self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID].apply( + self._petab_problem.mapping_df[petabv1.MODEL_ENTITY_ID].apply( _is_net_input ) ] .reset_index() - .set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID] + .set_index(petabv1.MODEL_ENTITY_ID)[petabv1.PETAB_ENTITY_ID] .to_dict() ) @@ -923,7 +989,7 @@ def _is_net_input(model_id): self._petab_problem.condition_df.loc[ condition_id, petab_id ], - petab.NOMINAL_VALUE, + petabv1.NOMINAL_VALUE, ], ) if self._petab_problem.condition_df.loc[ @@ -982,11 +1048,11 @@ def _is_net_input(model_id): else self.get_petab_parameter_by_id(petab_id) if petab_id in self.parameter_ids else self._petab_problem.parameter_df.loc[ - petab_id, petab.NOMINAL_VALUE + petab_id, petabv1.NOMINAL_VALUE ] if petab_id in set(self._petab_problem.parameter_df.index) else self._petab_problem.parameter_df.loc[ - hybridization_parameter_map[petab_id], petab.NOMINAL_VALUE + hybridization_parameter_map[petab_id], petabv1.NOMINAL_VALUE ] for model_id, petab_id in model_id_map.items() ] @@ -1004,7 +1070,7 @@ def _map_model_parameter_value( nn_output = self._eval_nn(pval, condition_id) if nn_output.size > 1: entityId = self._petab_problem.mapping_df.loc[ - pval, petab.MODEL_ENTITY_ID + pval, petabv1.MODEL_ENTITY_ID ] ind = int(re.search(r"\[\d+\]\[(\d+)\]", entityId).group(1)) return nn_output[ind] @@ -1015,36 +1081,102 @@ def _map_model_parameter_value( return self.get_petab_parameter_by_id(pval) def load_model_parameters( - self, simulation_condition: str + self, experiment: petabv2.Experiment, is_preeq: bool ) -> jt.Float[jt.Array, "np"]: """ - Load parameters for a simulation condition. + Load parameters for an experiment. - :param simulation_condition: - Simulation condition to load parameters for. + :param experiment: + Experiment to load parameters for. + :param is_preeq: + Whether to load preequilibration or simulation parameters. :return: - Parameters for the simulation condition. + Parameters for the experiment. """ - mapping = self._parameter_mappings[simulation_condition] - p = jnp.array( [ - self._map_model_parameter_value( - mapping, pname, simulation_condition + self._map_experiment_model_parameter_value( + pname, ind, experiment, is_preeq ) - for pname in self.model.parameter_ids + for ind, pname in enumerate(self.model.parameter_ids) ] ) pscale = tuple( [ - petab.LIN - if self._petab_problem.mapping_df is not None - and pname in self._petab_problem.mapping_df.index - else mapping.scale_map_sim_var[pname] - for pname in self.model.parameter_ids + petabv1.LIN + for _ in self.model.parameter_ids ] ) + return self._unscale(p, pscale) + + def _map_experiment_model_parameter_value( + self, pname: str, p_index: int, experiment: petabv2.Experiment, is_preeq: bool + ): + """ + Get values for the given parameter `pname` from the relevant petab tables. 
+ + :param pname: PEtab parameter id + :param p_index: Index of the parameter in the model's parameter list + :param experiment: PEtab experiment + :param is_preeq: Whether to get preequilibration or simulation parameter value + :return: Value of the parameter + """ + for p in experiment.periods: + if is_preeq: + if p.time >= 0.0: + continue + else: + condition_ids = p.condition_ids + break + else: + if p.time < 0.0: + continue + else: + condition_ids = p.condition_ids + break + + init_val = self.model.parameters[p_index] + if pname in self._petab_problem.parameter_df.index: + return self._petab_problem.parameter_df.loc[ + pname, petabv1.NOMINAL_VALUE + ] + elif pname in self._petab_problem.condition_df[petabv2.C.TARGET_ID].values: + target_row = self._petab_problem.condition_df[ + (self._petab_problem.condition_df[petabv2.C.TARGET_ID] == pname) & + (self._petab_problem.condition_df[petabv2.C.CONDITION_ID].isin(condition_ids)) + ] + if not target_row.empty: + target_value = target_row.iloc[0][petabv2.C.TARGET_VALUE] + return target_value + else: + for placeholder_col, param_col in ( + (petabv2.C.OBSERVABLE_PLACEHOLDERS, petabv2.C.OBSERVABLE_PARAMETERS), + (petabv2.C.NOISE_PLACEHOLDERS, petabv2.C.NOISE_PARAMETERS), + ): + placeholders = self._petab_problem.observable_df[ + placeholder_col + ].unique() + + for placeholders in placeholders: + placeholder_list = placeholders.split(";") + params_list = self._petab_problem.measurement_df[param_col][0].split(";") + for i, p in enumerate(placeholder_list): + if p == pname: + val = self._find_val(params_list[i]) + return val + return init_val + + def _find_val(self, param_entry: str): + val_float = _try_float(param_entry) + if isinstance(val_float, float): + return val_float + elif param_entry in self._petab_problem.parameter_df.index: + return self._petab_problem.parameter_df.loc[ + param_entry, petabv1.NOMINAL_VALUE + ] + else: + return param_entry # and hope I guess? def _state_needs_reinitialisation( self, @@ -1114,12 +1246,12 @@ def _state_reinitialisation_value( return jax_unscale( self.get_petab_parameter_by_id(xval), self._petab_problem.parameter_df.loc[ - xval, petab.PARAMETER_SCALE + xval, petabv1.PARAMETER_SCALE ], ) # only remaining option is nominal value for PEtab parameter # that is not estimated, return nominal value - return self._petab_problem.parameter_df.loc[xval, petab.NOMINAL_VALUE] + return self._petab_problem.parameter_df.loc[xval, petabv1.NOMINAL_VALUE] def load_reinitialisation( self, @@ -1169,9 +1301,11 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ return eqx.tree_at(lambda p: p.parameters, self, p) - def _prepare_conditions( + def _prepare_experiments( self, + experiments: list[petabv2.Experiment], conditions: list[str], + is_preeq: bool, op_numeric: np.ndarray | None = None, op_mask: np.ndarray | None = None, op_indices: np.ndarray | None = None, @@ -1186,10 +1320,14 @@ def _prepare_conditions( jt.Float[jt.Array, "nc nt nnp"], # noqa: F821, F722 ]: """ - Prepare conditions for simulation. + Prepare experiments for simulation. + :param experiments: + Experiments to prepare simulation arrays for. :param conditions: Simulation conditions to prepare. + :param is_preeq: + Whether to load preequilibration or simulation parameters. :param op_numeric: Numeric values for observable parameter overrides. If None, no overrides are used. :param op_mask: @@ -1207,24 +1345,48 @@ def _prepare_conditions( noise parameters. 
""" p_array = jnp.stack( - [self.load_model_parameters(sc) for sc in conditions] + [self.load_model_parameters(exp, is_preeq) for exp in experiments] + ) + + # experiments by total number of events - each experiment needs to mask out the events that aren't + # relevant for that experiment + # TODO: remove the duplication if we get rid of the JAX-specific negative event duplication in SbmlImporter + def _experiment_event_inds(experiment_ind): + num_periods = len(self._petab_problem.experiments[experiment_ind].periods) + return jnp.arange(experiment_ind * 2, experiment_ind * 2 + (num_periods * 2)) + + h_mask = jnp.stack( + [jnp.isin(jnp.arange(self.model.n_events), _experiment_event_inds(i)) for i, _ in enumerate(experiments)] ) + t_zeros = jnp.stack([ + exp.periods[0].time if exp.periods[0].time >= 0.0 else 0.0 for exp in experiments + ]) + if self.parameters.size: - unscaled_parameters = jnp.stack( - [ - jax_unscale( - self.parameters[ip], - self._petab_problem.parameter_df.loc[ - p_id, petab.PARAMETER_SCALE - ], - ) - for ip, p_id in enumerate(self.parameter_ids) - ] - ) + if isinstance(self._petab_problem, HybridV2Problem): + unscaled_parameters = jnp.stack( + [ + self.parameters[ip] + for ip, p_id in enumerate(self.parameter_ids) + ] + ) + else: + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], + self._petab_problem.parameter_df.loc[ + p_id, petabv1.PARAMETER_SCALE + ], + ) + for ip, p_id in enumerate(self.parameter_ids) + ] + ) else: unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) + # placeholder values from sundials code may be needed here if op_numeric is not None and op_numeric.size: op_array = jnp.where( op_mask, @@ -1259,7 +1421,7 @@ def _prepare_conditions( for sc, p in zip(conditions, p_array) ] ) - return p_array, mask_reinit_array, x_reinit_array, op_array, np_array + return p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros @eqx.filter_vmap( in_axes={ @@ -1281,6 +1443,7 @@ def run_simulation( x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 init_override: jt.Float[jt.Array, "nx"], # noqa: F821, F722 init_override_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 + h_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1290,13 +1453,14 @@ def run_simulation( max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 ts_mask: np.ndarray = np.array([]), + t_zeros: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ - Run a simulation for a given simulation condition. + Run a simulation for a given simulation experiment. :param p: - Parameters for the simulation condition + Parameters for the simulation experiment :param ts_dyn: (Padded) dynamic time points :param ts_posteq: @@ -1329,6 +1493,8 @@ def run_simulation( be initialised to the model default values. :param ts_mask: padding mask, see :meth:`JAXModel.simulate_condition` for details. + :param t_zeros: + simulation start time for the current experiment. :param ret: which output to return. See :class:`ReturnValue` for available options. 
:return: @@ -1349,6 +1515,8 @@ def run_simulation( init_override=init_override, init_override_mask=jax.lax.stop_gradient(init_override_mask), ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), + h_mask=jax.lax.stop_gradient(jnp.array(h_mask)), + t_zero=t_zeros, solver=solver, controller=controller, root_finder=root_finder, @@ -1362,7 +1530,7 @@ def run_simulation( def run_simulations( self, - simulation_conditions: list[str], + experiments: list[petabv2.Experiment], preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, @@ -1374,10 +1542,10 @@ def run_simulations( ret: ReturnValue = ReturnValue.llh, ): """ - Run simulations for a list of simulation conditions. + Run simulations for a list of simulation experiments. - :param simulation_conditions: - List of simulation conditions to run simulations for. + :param experiments: + Experiments to run simulations for. :param preeq_array: Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty. @@ -1396,9 +1564,15 @@ def run_simulations( Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions. """ - p_array, mask_reinit_array, x_reinit_array, op_array, np_array = ( - self._prepare_conditions( - simulation_conditions, + simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + dynamic_conditions = list(sc for sc in simulation_conditions if "preequilibration" not in sc) + dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) + + p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros = ( + self._prepare_experiments( + experiments, + dynamic_conditions, + False, self._op_numeric, self._op_mask, self._op_indices, @@ -1413,27 +1587,25 @@ def run_simulations( jnp.array( [ p - in set(self._parameter_mappings[sc].map_sim_var.keys()) + in set(self.model.parameter_ids) for p in self.model.state_ids ] ) - for sc in simulation_conditions + for _ in experiments ] ) init_override = jnp.stack( [ jnp.array( [ - self._eval_nn( - self._parameter_mappings[sc].map_sim_var[p], sc - ) + self._eval_nn(p, exp.periods[-1].condition_ids[0]) # TODO: Add mapping of p to eval_nn? if p - in set(self._parameter_mappings[sc].map_sim_var.keys()) + in set(self.model.parameter_ids) else 1.0 for p in self.model.state_ids ] ) - for sc in simulation_conditions + for exp in experiments ] ) @@ -1450,6 +1622,7 @@ def run_simulations( x_reinit_array, init_override, init_override_mask, + h_mask, solver, controller, root_finder, @@ -1457,6 +1630,7 @@ def run_simulations( max_steps, preeq_array, self._ts_masks, + t_zeros, ret, ) @@ -1471,6 +1645,7 @@ def run_preequilibration( p: jt.Float[jt.Array, "np"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + h_mask: jt.Bool[jt.Array, "ne"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1480,10 +1655,10 @@ def run_preequilibration( max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821 """ - Run a pre-equilibration simulation for a given simulation condition. 
+ Run a pre-equilibration simulation for a given simulation experiment. :param p: - Parameters for the simulation condition + Parameters for the simulation experiment :param mask_reinit: Mask for states that need reinitialisation :param x_reinit: @@ -1504,6 +1679,7 @@ def run_preequilibration( p=p, mask_reinit=mask_reinit, x_reinit=x_reinit, + h_mask=h_mask, solver=solver, controller=controller, root_finder=root_finder, @@ -1513,7 +1689,7 @@ def run_preequilibration( def run_preequilibrations( self, - simulation_conditions: list[str], + experiments: list[petabv2.Experiment], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1522,13 +1698,19 @@ def run_preequilibrations( ], max_steps: jnp.int_, ): - p_array, mask_reinit_array, x_reinit_array, _, _ = ( - self._prepare_conditions(simulation_conditions, None, None) + simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + preequilibration_conditions = list( + {sc for sc in simulation_conditions if "preequilibration" in sc} + ) + + p_array, mask_reinit_array, x_reinit_array, _, _, h_mask, _ = ( + self._prepare_experiments(experiments, preequilibration_conditions, True, None, None) ) return self.run_preequilibration( p_array, mask_reinit_array, x_reinit_array, + h_mask, solver, controller, root_finder, @@ -1539,7 +1721,7 @@ def run_preequilibrations( def run_simulations( problem: JAXProblem, - simulation_conditions: Iterable[tuple[str, ...]] | None = None, + simulation_experiments: Iterable[str] | None = None, solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), controller: diffrax.AbstractStepSizeController = diffrax.PIDController( **DEFAULT_CONTROLLER_SETTINGS @@ -1558,10 +1740,9 @@ def run_simulations( :param problem: Problem to run simulations for. - :param simulation_conditions: - Simulation conditions to run simulations for. This is a series of tuples, where each tuple contains the - simulation condition or the pre-equilibration condition followed by the simulation condition. Default is to run - simulations for all conditions. + :param simulation_experiments: + Simulation experiments to run simulations for. This is an iterable of experiment ids. + Default is to run simulations for all experiments. :param solver: ODE solver to use for simulation. :param controller: @@ -1578,49 +1759,54 @@ def run_simulations( :return: Overall output value and condition specific results and statistics. """ + if isinstance(problem, HybridProblem) or isinstance(problem._petab_problem, petabv1.Problem): + raise TypeError( + "run_simulations does not support PEtab v1 problems. Upgrade the problem to PEtab v2." 
+ ) + if isinstance(ret, str): ret = ReturnValue[ret] - if simulation_conditions is None: - simulation_conditions = problem.get_all_simulation_conditions() - - dynamic_conditions = [sc[0] for sc in simulation_conditions] - preequilibration_conditions = list( - {sc[1] for sc in simulation_conditions if len(sc) > 1} - ) + if simulation_experiments is None: + experiments = problem._petab_problem.experiments + else: + experiments = [exp for exp in problem._petab_problem.experiments if exp.id in simulation_experiments] + simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + dynamic_conditions = list(sc for sc in simulation_conditions if "preequilibration" not in sc) + dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) conditions = { "dynamic_conditions": dynamic_conditions, - "preequilibration_conditions": preequilibration_conditions, - "simulation_conditions": simulation_conditions, } - if preequilibration_conditions: + has_preeq = any(exp.periods[0].time < 0.0 for exp in experiments) + has_dynamic = any(exp.periods[-1].time >= 0.0 for exp in experiments) + + if has_preeq: preeqs, preresults = problem.run_preequilibrations( - preequilibration_conditions, + experiments, solver, controller, root_finder, steady_state_event, max_steps, ) + preeqs_array = preeqs else: preresults = { "stats_preeq": None, } - - if dynamic_conditions: - preeq_array = jnp.stack( + preeqs_array = jnp.stack( [ - preeqs[preequilibration_conditions.index(sc[1]), :] - if len(sc) > 1 - else jnp.array([]) - for sc in simulation_conditions + jnp.array([]) + for _ in experiments ] ) + + if has_dynamic: output, results = problem.run_simulations( - dynamic_conditions, - preeq_array, + experiments, + preeqs_array, solver, controller, root_finder, @@ -1639,6 +1825,11 @@ def run_simulations( } if ret in (ReturnValue.llh, ReturnValue.chi2): + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "ret: {}", + ret, + ) output = jnp.sum(output) return output, results | preresults | conditions @@ -1680,50 +1871,55 @@ def petab_simulate( max_steps=max_steps, ret=ReturnValue.y, ) - dfs = [] - for ic, sc in enumerate(r["dynamic_conditions"]): - obs = [ - problem.model.observable_ids[io] - for io in problem._iys[ic, problem._ts_masks[ic, :]] - ] - t = jnp.concat( - ( - problem._ts_dyn[ic, :], - problem._ts_posteq[ic, :], - ) - ) - df_sc = pd.DataFrame( - { - petab.SIMULATION: y[ic, problem._ts_masks[ic, :]], - petab.TIME: t[problem._ts_masks[ic, :]], - petab.OBSERVABLE_ID: obs, - petab.SIMULATION_CONDITION_ID: [sc] * len(t), - }, - index=problem._petab_measurement_indices[ic, :], - ) - if ( - petab.OBSERVABLE_PARAMETERS - in problem._petab_problem.measurement_df - ): - df_sc[petab.OBSERVABLE_PARAMETERS] = ( - problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" - )[petab.OBSERVABLE_PARAMETERS] - ) - if petab.NOISE_PARAMETERS in problem._petab_problem.measurement_df: - df_sc[petab.NOISE_PARAMETERS] = ( - problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" - )[petab.NOISE_PARAMETERS] + if isinstance(problem._petab_problem, HybridV2Problem): + return _build_simulation_df_v2(problem, y, r["dynamic_conditions"]) + else: + dfs = [] + for ic, sc in enumerate(r["dynamic_conditions"]): + obs = [ + problem.model.observable_ids[io] + for io in problem._iys[ic, problem._ts_masks[ic, :]] + ] + t = jnp.concat( + ( + problem._ts_dyn[ic, :], + problem._ts_posteq[ic, :], + ) ) - if ( - petab.PREEQUILIBRATION_CONDITION_ID 
- in problem._petab_problem.measurement_df - ): - df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = ( - problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" - )[petab.PREEQUILIBRATION_CONDITION_ID] + df_sc = pd.DataFrame( + { + petabv1.SIMULATION: y[ic, problem._ts_masks[ic, :]], + petabv1.TIME: t[problem._ts_masks[ic, :]], + petabv1.OBSERVABLE_ID: obs, + petabv1.SIMULATION_CONDITION_ID: [sc] * len(t), + }, + index=problem._petab_measurement_indices[ic, :], ) - dfs.append(df_sc) - return pd.concat(dfs).sort_index() + if ( + petabv1.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petabv1.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv1.SIMULATION_CONDITION_ID} == '{sc}'" + )[petabv1.OBSERVABLE_PARAMETERS] + ) + if petabv1.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petabv1.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv1.SIMULATION_CONDITION_ID} == '{sc}'" + )[petabv1.NOISE_PARAMETERS] + ) + if ( + petabv1.PREEQUILIBRATION_CONDITION_ID + in problem._petab_problem.measurement_df + ): + df_sc[petabv1.PREEQUILIBRATION_CONDITION_ID] = ( + problem._petab_problem.measurement_df.query( + f"{petabv1.SIMULATION_CONDITION_ID} == '{sc}'" + )[petabv1.PREEQUILIBRATION_CONDITION_ID] + ) + dfs.append(df_sc) + return pd.concat(dfs).sort_index() + + diff --git a/python/sdist/amici/sim/jax/__init__.py b/python/sdist/amici/sim/jax/__init__.py index 07744420df..a1cb94f564 100644 --- a/python/sdist/amici/sim/jax/__init__.py +++ b/python/sdist/amici/sim/jax/__init__.py @@ -1 +1,148 @@ """Functionality for simulating JAX-based AMICI models.""" + +import petab.v1 as petabv1 +import petab.v2 as petabv2 + +import pandas as pd +import jax.numpy as jnp + +def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): + """Add default experiment names to PEtab v2 problem. + + Args: + petab_problem: PEtab v2 problem to modify. + """ + if not hasattr(petab_problem, "extensions_config"): + petab_problem.extensions_config = {} + + petab_problem.visualization_df = None + + if petab_problem.condition_df is None: + default_condition = petabv2.core.Condition(id="__default__", changes=[], conditionId="__default__") + petab_problem.condition_tables[0].elements = [default_condition] + + if petab_problem.experiment_df is None or petab_problem.experiment_df.empty: + condition_ids = petab_problem.condition_df[petabv2.C.CONDITION_ID].values + condition_ids = [c for c in condition_ids if "preequilibration" not in c] + default_experiment = petabv2.core.Experiment( + id="__default__", + periods=[ + petabv2.core.ExperimentPeriod( + time=0.0, + condition_ids=condition_ids + ) + ], + ) + petab_problem.experiment_tables[0].elements = [default_experiment] + + measurement_tables = petab_problem.measurement_tables.copy() + for mt in measurement_tables: + for m in mt.elements: + m.experiment_id = "__default__" + + petab_problem.measurement_tables = measurement_tables + + return petab_problem + +def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: + """Get simulation conditions from PEtab v2 measurement DataFrame. + + Returns: + A pandas DataFrame mapping experiment_ids to condition ids. 
+ """ + experiment_df = petab_problem.experiment_df + exps = {} + for exp_id in experiment_df[petabv2.C.EXPERIMENT_ID].unique(): + exps[exp_id] = experiment_df[ + experiment_df[petabv2.C.EXPERIMENT_ID] == exp_id + ][petabv2.C.CONDITION_ID].unique() + + experiment_df = experiment_df.rename(columns={"conditionId": "simulationConditionId"}) + experiment_df = experiment_df.drop(columns=[petabv2.C.TIME]) + return experiment_df + +def _build_simulation_df_v2(problem, y, dyn_conditions): + """Build petab simulation DataFrame of similation results from a PEtab v2 problem.""" + dfs = [] + for ic, sc in enumerate(dyn_conditions): + experiment_id = _conditions_to_experiment_map( + problem._petab_problem.experiment_df + )[sc] + + if experiment_id == "__default__": + experiment_id = jnp.nan + + obs = [ + problem.model.observable_ids[io] + for io in problem._iys[ic, problem._ts_masks[ic, :]] + ] + t = jnp.concat( + ( + problem._ts_dyn[ic, :], + problem._ts_posteq[ic, :], + ) + ) + df_sc = pd.DataFrame( + { + petabv2.C.MODEL_ID: [float("nan")] * len(t), + petabv1.OBSERVABLE_ID: obs, + petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t), + petabv1.TIME: t[problem._ts_masks[ic, :]], + petabv1.SIMULATION: y[ic, problem._ts_masks[ic, :]], + }, + index=problem._petab_measurement_indices[ic, :], + ) + if ( + petabv1.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" + )[petabv2.C.OBSERVABLE_PARAMETERS] + ) + if petabv1.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petabv2.C.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" + )[petabv2.C.NOISE_PARAMETERS] + ) + dfs.append(df_sc) + return pd.concat(dfs).sort_index() + +def _conditions_to_experiment_map(experiment_df: pd.DataFrame) -> dict[str, str]: + condition_to_experiment = { + row.conditionId: row.experimentId + for row in experiment_df.itertuples() + } + return condition_to_experiment + +# def get_states_in_condition_table_v2( +# petab_problem, +# condition: dict | pd.Series = None, +# ) -> dict[str, tuple[float | str | None, float | str | None]]: +# """Get states and their initial condition as specified in the condition table. + +# Returns: Dictionary: ``stateId -> (initial condition simulation)`` +# """ +# states = { +# target_id: (target_value, None) +# if condition_id == condition[petabv1.SIMULATION_CONDITION_ID] +# else (None, None) +# for condition_id, target_id, target_value in zip( +# petab_problem.condition_df[petabv2.C.CONDITION_ID], +# petab_problem.condition_df[petabv2.C.TARGET_ID], +# petab_problem.condition_df[petabv2.C.TARGET_VALUE], +# ) +# } + +# return states + +def _try_float(value): + try: + return float(value) + except Exception as e: + msg = str(e).lower() + if isinstance(e, ValueError) and "could not convert" in msg: + return value + raise \ No newline at end of file diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 52d0d053ab..5b55852a44 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -44,6 +44,11 @@ def test_jax_llh(benchmark_problem): f"Skipping {problem_id} due to non-supported events in JAX." ) + if problem_id == "Oliveira_NatCommun2021": + pytest.skip( + "Skipping Oliveira_NatCommun2021 due to non-supported events in JAX." 
+ ) + amici_solver = amici_model.create_solver() cur_settings = settings[problem_id] amici_solver.set_absolute_tolerance(1e-8) diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 9271386e4a..94c829fd7c 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -50,6 +50,8 @@ def test_case(case, model_type, version, jax): f"implemented: {e}" ) pytest.skip(str(e)) + elif "run_simulations does not support PEtab v1" in str(e): + pytest.skip(str(e)) else: raise e diff --git a/tests/petab_test_suite/test_petab_v2_suite.py b/tests/petab_test_suite/test_petab_v2_suite.py index ba98eb3931..81b78399c7 100755 --- a/tests/petab_test_suite/test_petab_v2_suite.py +++ b/tests/petab_test_suite/test_petab_v2_suite.py @@ -4,6 +4,7 @@ import logging import sys +import diffrax import pandas as pd import petabtests import pytest @@ -66,28 +67,56 @@ def _test_case(case, model_type, version, jax): f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}" ) - pi = PetabImporter( - petab_problem=problem, - module_name=model_name, - compile_=True, - jax=jax, - ) - ps = pi.create_simulator( - force_import=True, - ) - ps.solver.set_steady_state_tolerance_factor(1.0) - - problem_parameters = problem.get_x_nominal_dict(free=True, fixed=True) - res = ps.simulate(problem_parameters=problem_parameters) - rdatas = res.rdatas - for rdata in rdatas: - assert rdata.status == AMICI_SUCCESS, ( - f"Simulation failed for {rdata.id}" + if jax: + from amici.jax import petab_simulate, run_simulations + + steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) + + pi = PetabImporter( + petab_problem=problem, + module_name=model_name, + compile_=True, + jax=jax, ) - chi2 = sum(rdata.chi2 for rdata in rdatas) - llh = res.llh - simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem) + + jax_problem = pi.create_simulator( + force_import=True, + ) + + llh, ret = run_simulations( + jax_problem, steady_state_event=steady_state_event + ) + chi2, _ = run_simulations( + jax_problem, ret="chi2", steady_state_event=steady_state_event + ) + simulation_df = petab_simulate( + jax_problem, steady_state_event=steady_state_event + ) + else: + pi = PetabImporter( + petab_problem=problem, + module_name=model_name, + compile_=True, + jax=jax, + ) + + ps = pi.create_simulator( + force_import=True, + ) + ps.solver.set_steady_state_tolerance_factor(1.0) + + problem_parameters = problem.get_x_nominal_dict(free=True, fixed=True) + res = ps.simulate(problem_parameters=problem_parameters) + + rdatas = res.rdatas + for rdata in rdatas: + assert rdata.status == AMICI_SUCCESS, ( + f"Simulation failed for {rdata.id}" + ) + chi2 = sum(rdata.chi2 for rdata in rdatas) + llh = res.llh + simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem) solution = petabtests.load_solution(case, model_type, version=version) gt_chi2 = solution[petabtests.CHI2] @@ -194,24 +223,25 @@ def run(): n_skipped = 0 n_total = 0 version = "v2.0.0" - jax = False - - cases = list(petabtests.get_cases("sbml", version=version)) - n_total += len(cases) - for case in cases: - try: - test_case(case, "sbml", version=version, jax=jax) - n_success += 1 - except Skipped: - n_skipped += 1 - except Exception as e: - # run all despite failures - logger.error(f"Case {case} failed.") - logger.exception(e) - - logger.info(f"{n_success} / {n_total} successful, {n_skipped} skipped") - if n_success != len(cases): - sys.exit(1) + + # for jax in 
(False, True): + for jax in (True,): + cases = list(petabtests.get_cases("sbml", version=version)) + n_total += len(cases) + for case in cases: + try: + test_case(case, "sbml", version=version, jax=jax) + n_success += 1 + except Skipped: + n_skipped += 1 + except Exception as e: + # run all despite failures + logger.error(f"Case {case} failed.") + logger.exception(e) + + logger.info(f"{n_success} / {n_total} successful, {n_skipped} skipped") + if n_success != len(cases): + sys.exit(1) if __name__ == "__main__":
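A minimal usage sketch of the JAX-backed PEtab v2 path that this diff wires up, mirroring the jax branch of tests/petab_test_suite/test_petab_v2_suite.py. The problem loading, the module name, and the import path for PetabImporter are illustrative assumptions; only PetabImporter(jax=True), create_simulator(), run_simulations(), petab_simulate(), and diffrax.steady_state_event() are taken from the changes above.

# Sketch only, assuming `problem` is a petab.v2.Problem loaded elsewhere
# (e.g. a PEtab test suite case) and that PetabImporter is importable from
# amici.importers.petab as suggested by the importer module in this diff.
import diffrax
from amici.importers.petab import PetabImporter  # import path assumed
from amici.jax import petab_simulate, run_simulations

importer = PetabImporter(
    petab_problem=problem,       # petab.v2.Problem, loaded elsewhere
    module_name="my_model_jax",  # placeholder module name
    compile_=True,
    jax=True,
)
# With jax=True, create_simulator() now returns a JAXProblem instead of a
# SUNDIALS-based PetabSimulator.
jax_problem = importer.create_simulator(force_import=True)

steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
llh, results = run_simulations(jax_problem, steady_state_event=steady_state_event)
simulation_df = petab_simulate(jax_problem, steady_state_event=steady_state_event)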