forked from AMICI-dev/AMICI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpysb_import.py
More file actions
318 lines (271 loc) · 10.9 KB
/
pysb_import.py
File metadata and controls
318 lines (271 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
PySB-PEtab Import
-----------------
Import a model in the PySB-adapted :mod:`petab`
(https://github.com/PEtab-dev/PEtab) format into AMICI.
"""
import logging
import re
from pathlib import Path
import petab.v1 as petab
import pysb
import pysb.bng
import sympy as sp
from petab.v1.C import (
CONDITION_NAME,
NOISE_FORMULA,
OBSERVABLE_FORMULA,
)
from petab.v1.models.pysb_model import PySBModel
from amici import MeasurementChannel
from ..logging import get_logger, log_execution_time, set_log_level
from . import PREEQ_INDICATOR_ID
from .import_helpers import (
get_fixed_parameters,
petab_noise_distribution_to_amici,
)
from .util import get_states_in_condition_table
logger = get_logger(__name__, logging.WARNING)
def _add_observation_model(
    pysb_model: pysb.Model, petab_problem: petab.Problem, jax: bool = False
):
    """Extend PySB model by observation model as defined in the PEtab
    observables table.

    Free symbols in observable and noise formulas that are not yet components
    of ``pysb_model`` are added as :class:`pysb.Parameter` with value ``1.0``.
    Then, for each observable, a :class:`pysb.Expression` named after the
    observable ID (an already existing expression of that name is reused) and
    an expression ``{observable_id}_sigma`` for its noise formula are added.

    :param pysb_model:
        PySB model, modified in place.
    :param petab_problem:
        PEtab problem providing the observables table. If the problem has no
        observables table, the model is left unchanged.
    :param jax:
        If ``True``, placeholders ``observableParameter{n}_{suffix}`` and
        ``noiseParameter{n}_{suffix}`` are mapped onto shared parameters
        ``observableParameter{n}`` / ``noiseParameter{n}``, and the formulas
        are rewritten accordingly.
    """
    observable_df = petab_problem.observable_df
    if observable_df is None:
        # no observation model to add
        return

    # map component names to components, so that sympify resolves them to
    # the existing model components rather than to fresh plain symbols
    local_syms = {
        sp.Symbol.__str__(comp): comp
        for comp in pysb_model.components
        if isinstance(comp, sp.Symbol)
    }
    obs_df = observable_df.copy()

    # add any required output parameters
    for col, placeholder_pattern in (
        (OBSERVABLE_FORMULA, r"^(observableParameter\d+)_\w+$"),
        (NOISE_FORMULA, r"^(noiseParameter\d+)_\w+$"),
    ):
        for ir, formula in observable_df[col].items():
            if not isinstance(formula, str):
                continue
            changed_formula = False
            sym = sp.sympify(formula, locals=local_syms)
            # free_symbols is evaluated once here, so substituting into
            # `sym` inside the loop does not affect the iteration
            for s in sym.free_symbols:
                if isinstance(s, pysb.Component):
                    continue
                if jax:
                    name = re.sub(placeholder_pattern, r"\1", str(s))
                else:
                    name = str(s)
                # In JAX mode, placeholders of multiple observables are
                # mapped to the same parameter, so create and add it only
                # once instead of re-adding a duplicate-named component.
                if name not in local_syms:
                    p = pysb.Parameter(name, 1.0)
                    pysb_model.add_component(p)
                    local_syms[name] = p
                # replace placeholder with the shared parameter
                if jax and name != str(s):
                    changed_formula = True
                    sym = sym.subs(s, local_syms[name])
            # update formula
            if jax and changed_formula:
                obs_df.at[ir, col] = (
                    sym.name if isinstance(sym, sp.Symbol) else str(sym)
                )

    # add observables and sigmas to pysb model
    for observable_id, observable_formula, noise_formula in zip(
        obs_df.index,
        obs_df[OBSERVABLE_FORMULA],
        obs_df[NOISE_FORMULA],
        strict=True,
    ):
        if observable_id in pysb_model.expressions.keys():
            # reuse an expression of the same name that the model defines
            obs_expr = pysb_model.expressions[observable_id]
        else:
            obs_expr = pysb.Expression(
                observable_id,
                sp.sympify(observable_formula, locals=local_syms),
            )
            pysb_model.add_component(obs_expr)
        local_syms[observable_id] = obs_expr

        sigma_id = f"{observable_id}_sigma"
        sigma_expr = pysb.Expression(
            sigma_id, sp.sympify(noise_formula, locals=local_syms)
        )
        pysb_model.add_component(sigma_expr)
        local_syms[sigma_id] = sigma_expr
def _add_initialization_variables(
    pysb_model: pysb.Model, petab_problem: petab.Problem
) -> list[str]:
    """Add initialization variables to the PySB model to support initial
    conditions specified in the PEtab condition table.

    To parameterize initial states, we currently need initial assignments.
    For every species ``{speciesID}`` occurring in the condition table, the
    parameters ``initial_{speciesID}_preeq`` and ``initial_{speciesID}_sim``
    are created. An expression ``initial_{speciesID}_formula`` selects
    between them via the preequilibration indicator parameter
    ``PREEQ_INDICATOR_ID``. That expression replaces the value of a matching
    existing initial, or is added as a new initial. Feels dirty and should be
    changed (see also #924).

    :param pysb_model:
        PySB model, modified in place.
    :param petab_problem:
        PEtab problem. Condition-table state IDs are expected to have the
        form ``__s{index}`` and index into the species of the *original*
        model ``petab_problem.model.model``, which therefore must already
        have its species generated.
    :returns:
        IDs of parameters that must be treated as fixed: the
        preequilibration indicator if any initial states are present,
        otherwise an empty list.
    :raises AssertionError:
        If the model already has a component named ``PREEQ_INDICATOR_ID``.
    :raises ValueError:
        If a parameter to be created already exists in the model, or if a
        condition-table state ID is not of the form ``__s{index}``.
    """
    from pysb.pattern import match_complex_pattern

    initial_states = get_states_in_condition_table(petab_problem)
    fixed_parameters = []
    if not initial_states:
        return fixed_parameters

    # add preequilibration indicator variable
    # NOTE: would only be required if we actually have preequilibration
    # adding it anyways. can be optimized-out later
    if PREEQ_INDICATOR_ID in {c.name for c in pysb_model.components}:
        raise AssertionError(
            "Model already has a component with ID "
            f"{PREEQ_INDICATOR_ID}. Cannot handle "
            "species and compartments in condition table "
            "then."
        )
    preeq_indicator = pysb.Parameter(PREEQ_INDICATOR_ID)
    pysb_model.add_component(preeq_indicator)
    # Can only reset parameters after preequilibration if they are fixed.
    fixed_parameters.append(PREEQ_INDICATOR_ID)
    logger.debug(
        f"Adding preequilibration indicator constant {PREEQ_INDICATOR_ID}"
    )

    logger.debug(f"Adding initial assignments for {initial_states.keys()}")

    # Species patterns come from the _original_ model, but we only want to
    # modify pysb_model, so each pattern is reconstituted from its string
    # representation with names resolved against pysb_model's components.
    # A private namespace is used instead of this module's globals(), so
    # component names (e.g. a monomer called `re`) cannot shadow
    # module-level names. The evaluated string is derived from the user's
    # own model, not from external input.
    component_namespace = {c.name: c for c in pysb_model.components}

    for assignee_id in initial_states:
        init_par_id_preeq = f"initial_{assignee_id}_preeq"
        init_par_id_sim = f"initial_{assignee_id}_sim"
        init_pars = []
        existing_names = {c.name for c in pysb_model.components}
        for init_par_id in (init_par_id_preeq, init_par_id_sim):
            if init_par_id in existing_names:
                raise ValueError(
                    "Cannot create parameter for initial assignment "
                    f"for {assignee_id} because an entity named "
                    f"{init_par_id} exists already in the model."
                )
            p = pysb.Parameter(init_par_id)
            pysb_model.add_component(p)
            init_pars.append(p)
        par_preeq, par_sim = init_pars

        species_match = re.match(r"__s(\d+)$", assignee_id)
        if species_match is None:
            raise ValueError(
                f"Unexpected state ID {assignee_id!r} in condition table, "
                "expected a species ID of the form '__s<index>'."
            )
        species_idx = int(species_match[1])
        # use original model here since that's what was used to generate
        # the ids in initial_states
        original_pattern = petab_problem.model.model.species[species_idx]
        species_pattern = pysb.as_complex_pattern(
            eval(str(original_pattern), {}, component_namespace)
        )

        # initial value: preequilibration parameter while the indicator is
        # 1, simulation parameter while it is 0
        formula = pysb.Expression(
            f"initial_{assignee_id}_formula",
            preeq_indicator * par_preeq + (1 - preeq_indicator) * par_sim,
        )
        pysb_model.add_component(formula)

        for initial in pysb_model.initials:
            if match_complex_pattern(
                initial.pattern, species_pattern, exact=True
            ):
                logger.debug(
                    "The PySB model has an initial defined for species "
                    f"{assignee_id}, but this species also has an initial "
                    "value defined in the PEtab condition table. The PySB "
                    "initial will be overwritten to handle "
                    "preequilibration and initial values specified by the "
                    "PEtab problem."
                )
                initial.value = formula
                break
        else:
            # No initial in the pysb model, so add one
            pysb_model.add_component(pysb.Initial(species_pattern, formula))

    return fixed_parameters
def _check_condition_table(
    pysb_model: pysb.Model, petab_problem: petab.Problem
) -> None:
    """Check that the condition table only references supported entities.

    Every column except ``CONDITION_NAME`` is resolved through the PEtab
    mapping table. It must then name a parameter of ``pysb_model`` or a
    species/pattern reported by ``get_states_in_condition_table``.

    :raises NotImplementedError:
        For any other column (e.g. compartments).
    """
    # important to use pysb_model here, as we want to also cover output
    # parameters
    model_parameters = {p.name for p in pysb_model.parameters}
    condition_species_parameters = get_states_in_condition_table(
        petab_problem, return_patterns=True
    )
    for column in petab_problem.condition_df.columns:
        if column == CONDITION_NAME:
            continue
        target_id = petab.mapping.resolve_mapping(
            petab_problem.mapping_df, column
        )
        # parameters or species/patterns
        if (
            target_id in model_parameters
            or target_id in condition_species_parameters
        ):
            continue
        raise NotImplementedError(
            "For PySB PEtab import, only model parameters and species, but "
            "not compartments are allowed in the condition table. Offending "
            f"column: {target_id}"
        )


@log_execution_time("Importing PEtab model", logger)
def import_model_pysb(
    petab_problem: petab.Problem,
    model_output_dir: str | Path | None = None,
    verbose: bool | int | None = True,
    model_name: str | None = None,
    jax: bool = False,
    **kwargs,
) -> None:
    """
    Create AMICI model from PySB-PEtab problem

    :param petab_problem:
        PySB PEtab problem
    :param model_output_dir:
        Directory to write the model code to. Will be created if doesn't
        exist. Defaults to current directory.
    :param verbose:
        Print/log extra information. Also forwarded to the model code
        generation.
    :param model_name:
        Name of the generated model module
    :param jax:
        Whether to generate JAX code instead of C++ code.
    :param kwargs:
        Additional keyword arguments to be passed to
        :func:`amici.pysb_import.pysb2amici`.
    :raises ValueError:
        If ``petab_problem.model`` is not a PySB model.
    :raises NotImplementedError:
        If the condition table contains columns other than model parameters
        and species (e.g. compartments).
    """
    set_log_level(logger, verbose)

    logger.info("Importing model ...")

    if not isinstance(petab_problem.model, PySBModel):
        raise ValueError("Not a PySB model")

    # need to create a copy here as we don't want to modify the original
    pysb.SelfExporter.cleanup()
    og_export = pysb.SelfExporter.do_export
    pysb.SelfExporter.do_export = False
    try:
        pysb_model = pysb.Model(
            base=petab_problem.model.model,
            name=petab_problem.model.model_id,
        )
        _add_observation_model(pysb_model, petab_problem, jax)
        # generate species for the _original_ model
        pysb.bng.generate_equations(petab_problem.model.model)
        fixed_parameters = _add_initialization_variables(
            pysb_model, petab_problem
        )
    finally:
        # restore the process-wide export setting even if the import fails
        pysb.SelfExporter.do_export = og_export

    # check condition table for supported features
    _check_condition_table(pysb_model, petab_problem)

    constant_parameters = (
        get_fixed_parameters(petab_problem) + fixed_parameters
    )

    if petab_problem.observable_df is None:
        observation_model = []
    else:
        # observable and sigma expressions were added to pysb_model by
        # _add_observation_model under these names
        observation_model = [
            MeasurementChannel(
                id_=observable.name,
                sigma=f"{observable.name}_sigma",
                noise_distribution=petab_noise_distribution_to_amici(
                    observable
                ),
            )
            for _, observable in petab_problem.observable_df.iterrows()
        ]

    if jax:
        from amici.pysb_import import pysb2jax

        pysb2jax(
            model=pysb_model,
            output_dir=model_output_dir,
            model_name=model_name,
            verbose=verbose,
            observation_model=observation_model,
            pysb_model_has_obs_and_noise=True,
            **kwargs,
        )
        return

    from amici.pysb_import import pysb2amici

    pysb2amici(
        model=pysb_model,
        output_dir=model_output_dir,
        model_name=model_name,
        verbose=verbose,
        constant_parameters=constant_parameters,
        observation_model=observation_model,
        pysb_model_has_obs_and_noise=True,
        **kwargs,
    )