Skip to content

Commit fb8ee2c

Browse files
committed
PEtab v2 import
Support for PEtab v2. **WIP**
1 parent baa723e commit fb8ee2c

File tree

9 files changed

+893
-51
lines changed

9 files changed

+893
-51
lines changed

pypesto/objective/amici/amici.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
try:
4242
import amici
43+
import amici.petab.petab_importer
44+
import pandas as pd
4345
from amici.petab.parameter_mapping import ParameterMapping
4446
except ImportError:
4547
pass
@@ -723,3 +725,67 @@ def update_from_problem(
723725
) in condition_mapping.map_preeq_fix.items():
724726
if (val := id_to_val.get(mapped_to_par)) is not None:
725727
condition_mapping.map_preeq_fix[model_par] = val
728+
729+
730+
class AmiciPetabV2Objective(AmiciObjective):
731+
"""An AMICI objective constructed from a PEtab v2 problem."""
732+
733+
def __init__(
734+
self,
735+
petab_importer: amici.petab.petab_importer.PetabImporter,
736+
**kwargs,
737+
) -> None:
738+
from .amici_calculator import AmiciCalculatorPetabV2
739+
740+
self._petab_simulator: amici.petab.petab_importer.PetabSimulator = (
741+
petab_importer.create_simulator()
742+
)
743+
self.petab_problem = petab_importer.petab_problem
744+
amici_model = self._petab_simulator.model
745+
amici_solver = self._petab_simulator.solver
746+
edatas = self._petab_simulator.exp_man.create_edatas()
747+
748+
super().__init__(
749+
amici_model=amici_model,
750+
amici_solver=amici_solver,
751+
edatas=edatas,
752+
calculator=AmiciCalculatorPetabV2(self._petab_simulator),
753+
**kwargs,
754+
)
755+
756+
def __deepcopy__(self, memo=None):
757+
"""Override AmiciObjective.__deepcopy__."""
758+
if memo is None:
759+
memo = {}
760+
cls = self.__class__
761+
result = cls.__new__(cls)
762+
memo[id(self)] = result
763+
for k, v in self.__dict__.items():
764+
setattr(result, k, copy.deepcopy(v, memo))
765+
return result
766+
767+
def __getstate__(self) -> dict:
768+
"""Use Python's default pickling semantics (shallow copy of instance dict)."""
769+
return dict(self.__dict__)
770+
771+
def __setstate__(self, state: dict) -> None:
772+
"""Restore state using the instance dict (default unpickling behaviour)."""
773+
self.__dict__.update(state)
774+
775+
def rdatas_to_simulation_df(
776+
self,
777+
rdatas: Sequence[amici.ReturnData],
778+
) -> pd.DataFrame:
779+
"""
780+
See :meth:`rdatas_to_measurement_df`.
781+
782+
Except a petab simulation dataframe is created, i.e. the measurement
783+
column label is adjusted.
784+
"""
785+
from amici.petab.petab_importer import rdatas_to_simulation_df
786+
787+
return rdatas_to_simulation_df(
788+
rdatas,
789+
self._petab_simulator._model,
790+
self._petab_simulator._petab_problem,
791+
)

pypesto/objective/amici/amici_calculator.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,136 @@ def __call__(
149149
)
150150

151151

152+
class AmiciCalculatorPetabV2(AmiciCalculator):
153+
"""Class to perform the AMICI call and obtain objective function values."""
154+
155+
def __init__(
156+
self,
157+
petab_simulator: amici.petab.petab_importer.PetabSimulator,
158+
**kwargs,
159+
):
160+
super().__init__(**kwargs)
161+
self.petab_simulator = petab_simulator
162+
163+
def __call__(
164+
self,
165+
x_dct: dict,
166+
sensi_orders: tuple[int],
167+
mode: ModeType,
168+
amici_model: AmiciModel,
169+
amici_solver: AmiciSolver,
170+
edatas: list[amici.ExpData],
171+
n_threads: int,
172+
x_ids: Sequence[str],
173+
parameter_mapping: ParameterMapping,
174+
fim_for_hess: bool,
175+
):
176+
"""Perform the actual AMICI call.
177+
178+
Called within the :func:`AmiciObjective.__call__` method.
179+
180+
Parameters
181+
----------
182+
x_dct:
183+
Parameters for which to compute function value and derivatives.
184+
sensi_orders:
185+
Tuple of requested sensitivity orders.
186+
mode:
187+
Call mode (function value or residual based).
188+
amici_model:
189+
The AMICI model.
190+
amici_solver:
191+
The AMICI solver.
192+
edatas:
193+
The experimental data.
194+
n_threads:
195+
Number of threads for AMICI call.
196+
x_ids:
197+
Ids of optimization parameters.
198+
parameter_mapping:
199+
Mapping of optimization to simulation parameters.
200+
fim_for_hess:
201+
Whether to use the FIM (if available) instead of the Hessian (if
202+
requested).
203+
"""
204+
amici_solver = self.petab_simulator._solver
205+
206+
if 2 in sensi_orders:
207+
raise NotImplementedError(
208+
"Second order sensitivities are not yet supported for "
209+
"PEtab v2."
210+
)
211+
212+
if mode != MODE_FUN:
213+
raise NotImplementedError(
214+
"Only function value mode is currently supported for "
215+
f"PEtab v2. Got mode {mode}."
216+
)
217+
218+
# TODO: -> method
219+
# set order in solver
220+
sensi_order = 0
221+
if sensi_orders:
222+
sensi_order = max(sensi_orders)
223+
224+
if sensi_order == 2 and fim_for_hess:
225+
# we use the FIM
226+
amici_solver.set_sensitivity_order(sensi_order - 1)
227+
else:
228+
amici_solver.set_sensitivity_order(sensi_order)
229+
230+
# run amici simulation
231+
res = self.petab_simulator.simulate(x_dct)
232+
rdatas = res[RDATAS]
233+
234+
if (
235+
not self._known_least_squares_safe
236+
and mode == MODE_RES
237+
and 1 in sensi_orders
238+
):
239+
if not amici_model.get_add_sigma_residuals() and any(
240+
(
241+
(r["ssigmay"] is not None and np.any(r["ssigmay"]))
242+
or (r["ssigmaz"] is not None and np.any(r["ssigmaz"]))
243+
)
244+
for r in rdatas
245+
):
246+
raise RuntimeError(
247+
"Cannot use least squares solver with"
248+
"parameter dependent sigma! Support can be "
249+
"enabled via "
250+
"amici_model.setAddSigmaResiduals()."
251+
)
252+
self._known_least_squares_safe = True # don't check this again
253+
254+
grad = None
255+
if 1 in sensi_orders:
256+
if res["sllh"] is None and np.isnan(res["llh"]):
257+
# TODO: to amici -- set sllh even if llh is nan
258+
grad = np.full(len(x_ids), np.nan)
259+
else:
260+
# llh to nllh, dict to array
261+
grad = -np.array(
262+
[
263+
res["sllh"][x_id] # if x_id in res["sllh"] else 0.0
264+
for x_id in x_ids
265+
if x_id in x_dct.keys()
266+
]
267+
)
268+
269+
ret = {
270+
FVAL: -res["llh"],
271+
GRAD: grad,
272+
# TODO
273+
# HESS: s2nllh,
274+
# RES: res,
275+
# SRES: sres,
276+
RDATAS: rdatas,
277+
}
278+
279+
return filter_return_dict(ret)
280+
281+
152282
def calculate_function_values(
153283
rdatas,
154284
sensi_orders: tuple[int, ...],

pypesto/petab/importer.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pandas as pd
1616
import petab.v1 as petab
17+
from petab import v2
1718

1819
try:
1920
import roadrunner
@@ -39,6 +40,7 @@
3940
from ..startpoint import StartpointMethod
4041
from .objective_creator import (
4142
AmiciObjectiveCreator,
43+
AmiciPetabV2ObjectiveCreator,
4244
ObjectiveCreator,
4345
PetabSimulatorObjectiveCreator,
4446
RoadRunnerObjectiveCreator,
@@ -70,7 +72,7 @@ class PetabImporter:
7072

7173
def __init__(
7274
self,
73-
petab_problem: petab.Problem,
75+
petab_problem: petab.Problem | v2.Problem,
7476
output_folder: str | None = None,
7577
model_name: str | None = None,
7678
validate_petab: bool = True,
@@ -149,8 +151,18 @@ def __init__(
149151

150152
self.validate_petab = validate_petab
151153
if self.validate_petab:
152-
if petab.lint_problem(petab_problem):
154+
if isinstance(petab_problem, petab.Problem) and petab.lint_problem(
155+
petab_problem
156+
):
153157
raise ValueError("Invalid PEtab problem.")
158+
if (
159+
isinstance(petab_problem, v2.Problem)
160+
and (
161+
validation_result := petab_problem.validate()
162+
).has_errors()
163+
):
164+
validation_result.log(logger=logger)
165+
raise ValueError("Invalid PEtab v2 problem.")
154166
if self._hierarchical and validate_petab_hierarchical:
155167
from ..hierarchical.petab import (
156168
validate_hierarchical_petab_problem,
@@ -186,7 +198,15 @@ def from_yaml(
186198
simulator_type: str = AMICI,
187199
) -> PetabImporter:
188200
"""Simplified constructor using a petab yaml file."""
189-
petab_problem = petab.Problem.from_yaml(yaml_config)
201+
from petab.versions import get_major_version
202+
203+
match get_major_version(yaml_config):
204+
case 1:
205+
petab_problem = petab.Problem.from_yaml(yaml_config)
206+
case 2:
207+
petab_problem = v2.Problem.from_yaml(yaml_config)
208+
case _:
209+
raise ValueError("Only PEtab v1 and v2 are supported.")
190210

191211
return PetabImporter(
192212
petab_problem=petab_problem,
@@ -278,6 +298,42 @@ def create_prior(self) -> NegLogParameterPriors | None:
278298
else:
279299
return None
280300

301+
def _create_prior_v2(self) -> NegLogParameterPriors | None:
302+
"""Create prior for PEtab v2 problem."""
303+
import petab.v2.C as petab_c
304+
305+
import pypesto.C as pypesto_c
306+
307+
petab_to_pypesto = {
308+
petab_c.LAPLACE: pypesto_c.LAPLACE,
309+
petab_c.LOG_LAPLACE: pypesto_c.LOG_LAPLACE,
310+
petab_c.LOG_NORMAL: pypesto_c.LOG_NORMAL,
311+
petab_c.LOG_UNIFORM: pypesto_c.LOG_UNIFORM,
312+
petab_c.NORMAL: pypesto_c.NORMAL,
313+
petab_c.UNIFORM: pypesto_c.UNIFORM,
314+
}
315+
316+
prior_list = []
317+
for parameter in self.petab_problem.parameters:
318+
if not parameter.estimate or parameter.prior_distribution is None:
319+
continue
320+
321+
prior_list.append(
322+
get_parameter_prior_dict(
323+
index=len(prior_list),
324+
# TODO map names
325+
prior_type=petab_to_pypesto.get(
326+
str(parameter.prior_distribution),
327+
str(parameter.prior_distribution),
328+
),
329+
prior_parameters=parameter.prior_parameters,
330+
parameter_scale="lin",
331+
)
332+
)
333+
if prior_list:
334+
return NegLogParameterPriors(prior_list)
335+
return None
336+
281337
def create_startpoint_method(self, **kwargs) -> StartpointMethod:
282338
"""Create a startpoint method.
283339
@@ -307,6 +363,22 @@ def create_objective_creator(
307363
has to be provided. Otherwise the argument is not used.
308364
309365
"""
366+
if isinstance(self.petab_problem, v2.Problem):
367+
if simulator_type != AMICI:
368+
raise ValueError(
369+
"Only 'amici' simulator type is supported for PEtab v2 "
370+
"problems."
371+
)
372+
return AmiciPetabV2ObjectiveCreator(
373+
petab_problem=self.petab_problem,
374+
output_folder=self.output_folder,
375+
model_name=self.model_name,
376+
hierarchical=self._hierarchical,
377+
inner_options=self.inner_options,
378+
non_quantitative_data_types=self._non_quantitative_data_types,
379+
validate_petab=self.validate_petab,
380+
)
381+
310382
if simulator_type == AMICI:
311383
return AmiciObjectiveCreator(
312384
petab_problem=self.petab_problem,
@@ -365,10 +437,26 @@ def create_problem(
365437
objective = self.objective_constructor.create_objective(**kwargs)
366438

367439
x_fixed_indices = self.petab_problem.x_fixed_indices
368-
x_fixed_vals = self.petab_problem.x_nominal_fixed_scaled
369440
x_ids = self.petab_problem.x_ids
370-
lb = self.petab_problem.lb_scaled
371-
ub = self.petab_problem.ub_scaled
441+
if isinstance(self.petab_problem, petab.Problem):
442+
# PEtab v1
443+
x_fixed_vals = self.petab_problem.x_nominal_fixed_scaled
444+
lb = self.petab_problem.lb_scaled
445+
ub = self.petab_problem.ub_scaled
446+
x_scales = [
447+
self.petab_problem.parameter_df.loc[
448+
x_id, petab.PARAMETER_SCALE
449+
]
450+
for x_id in x_ids
451+
]
452+
prior = self.create_prior()
453+
else:
454+
# PEtab v2 -- no parameter scaling
455+
x_fixed_vals = self.petab_problem.x_nominal_fixed
456+
lb = self.petab_problem.lb
457+
ub = self.petab_problem.ub
458+
x_scales = [petab.LIN for x_id in x_ids]
459+
prior = self._create_prior_v2()
372460

373461
# Raise error if the correct calculator is not used.
374462
if self._hierarchical:
@@ -388,19 +476,12 @@ def create_problem(
388476
map(x_ids.index, self.petab_problem.x_fixed_ids)
389477
)
390478

391-
x_scales = [
392-
self.petab_problem.parameter_df.loc[x_id, petab.PARAMETER_SCALE]
393-
for x_id in x_ids
394-
]
395-
396479
if problem_kwargs is None:
397480
problem_kwargs = {}
398481

399482
if startpoint_kwargs is None:
400483
startpoint_kwargs = {}
401484

402-
prior = self.create_prior()
403-
404485
if prior is not None:
405486
if self._hierarchical:
406487
raise NotImplementedError(

0 commit comments

Comments
 (0)