Skip to content

Commit 0ed3e0f

Browse files
authored
Method to unflatten simulation df produced by flattened PEtab problem (#171)
1 parent 6c10c28 commit 0ed3e0f

File tree

2 files changed

+170
-28
lines changed

2 files changed

+170
-28
lines changed

petab/core.py

Lines changed: 157 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import logging
44
import os
55
import re
6-
from typing import Iterable, Optional, Callable, Union, Any, Sequence, List
6+
from typing import (
7+
Iterable, Optional, Callable, Union, Any, Sequence, List, Dict,
8+
)
79
from warnings import warn
810

911
import numpy as np
@@ -17,7 +19,16 @@
1719
'write_visualization_df', 'get_notnull_columns',
1820
'flatten_timepoint_specific_output_overrides',
1921
'concat_tables', 'to_float_if_float', 'is_empty',
20-
'create_combine_archive', 'unique_preserve_order']
22+
'create_combine_archive', 'unique_preserve_order',
23+
'unflatten_simulation_df']
24+
25+
POSSIBLE_GROUPVARS_FLATTENED_PROBLEM = [
26+
OBSERVABLE_ID,
27+
OBSERVABLE_PARAMETERS,
28+
NOISE_PARAMETERS,
29+
SIMULATION_CONDITION_ID,
30+
PREEQUILIBRATION_CONDITION_ID,
31+
]
2132

2233

2334
def get_simulation_df(simulation_file: Union[str, Path]) -> pd.DataFrame:
@@ -90,6 +101,99 @@ def get_notnull_columns(df: pd.DataFrame, candidates: Iterable):
90101
if col in df and not np.all(df[col].isnull())]
91102

92103

104+
def get_observable_replacement_id(groupvars, groupvar) -> str:
105+
"""Get the replacement ID for an observable.
106+
107+
Arguments:
108+
groupvars:
109+
The columns of a PEtab measurement table that should be unique
110+
between observables in a flattened PEtab problem.
111+
groupvar:
112+
A specific grouping of `groupvars`.
113+
114+
Returns:
115+
The observable replacement ID.
116+
"""
117+
replacement_id = ''
118+
for field in POSSIBLE_GROUPVARS_FLATTENED_PROBLEM:
119+
if field in groupvars:
120+
val = str(groupvar[groupvars.index(field)])\
121+
.replace(PARAMETER_SEPARATOR, '_').replace('.', '_')
122+
if replacement_id == '':
123+
replacement_id = val
124+
elif val != '':
125+
replacement_id += f'__{val}'
126+
return replacement_id
127+
128+
129+
def get_hyperparameter_replacement_id(
130+
hyperparameter_type,
131+
observable_replacement_id,
132+
):
133+
"""Get the full ID for a replaced hyperparameter.
134+
135+
Arguments:
136+
hyperparameter_type:
137+
The type of hyperparameter, e.g. `noiseParameter`.
138+
observable_replacement_id:
139+
The observable replacement ID, e.g. the output of
140+
`get_observable_replacement_id`.
141+
142+
Returns:
143+
The hyperparameter replacement ID, with a field that will be replaced
144+
by the first matched substring in a regex substitution.
145+
"""
146+
return f'{hyperparameter_type}\\1_{observable_replacement_id}'
147+
148+
149+
def get_flattened_id_mappings(
150+
petab_problem: 'petab.problem.Problem',
151+
) -> Dict[str, Dict[str, str]]:
152+
"""Get mapping from unflattened to flattened observable IDs.
153+
154+
Arguments:
155+
petab_problem:
156+
The unflattened PEtab problem.
157+
158+
Returns:
159+
A dictionary of dictionaries. Each inner dictionary is a mapping
160+
from original ID to flattened ID. Each outer dictionary is the mapping
161+
for either: observable IDs; noise parameter IDs; or, observable
162+
parameter IDs.
163+
"""
164+
groupvars = get_notnull_columns(petab_problem.measurement_df,
165+
POSSIBLE_GROUPVARS_FLATTENED_PROBLEM)
166+
mappings = {
167+
OBSERVABLE_ID: {},
168+
NOISE_PARAMETERS: {},
169+
OBSERVABLE_PARAMETERS: {},
170+
}
171+
for groupvar, measurements in \
172+
petab_problem.measurement_df.groupby(groupvars, dropna=False):
173+
observable_id = groupvar[groupvars.index(OBSERVABLE_ID)]
174+
observable_replacement_id = \
175+
get_observable_replacement_id(groupvars, groupvar)
176+
177+
logger.debug(f'Creating synthetic observable {observable_id}')
178+
if observable_replacement_id in petab_problem.observable_df.index:
179+
raise RuntimeError('could not create synthetic observables '
180+
f'since {observable_replacement_id} was '
181+
'already present in observable table')
182+
183+
mappings[OBSERVABLE_ID][observable_replacement_id] = observable_id
184+
185+
for field, hyperparameter_type, target in [
186+
(NOISE_PARAMETERS, 'noiseParameter', NOISE_FORMULA),
187+
(OBSERVABLE_PARAMETERS, 'observableParameter', OBSERVABLE_FORMULA)
188+
]:
189+
if field in measurements:
190+
mappings[field][get_hyperparameter_replacement_id(
191+
hyperparameter_type=hyperparameter_type,
192+
observable_replacement_id=observable_replacement_id,
193+
)] = fr'{hyperparameter_type}([0-9]+)_{observable_id}'
194+
return mappings
195+
196+
93197
def flatten_timepoint_specific_output_overrides(
94198
petab_problem: 'petab.problem.Problem',
95199
) -> None:
@@ -109,44 +213,38 @@ def flatten_timepoint_specific_output_overrides(
109213
"""
110214
new_measurement_dfs = []
111215
new_observable_dfs = []
112-
possible_groupvars = [OBSERVABLE_ID, OBSERVABLE_PARAMETERS,
113-
NOISE_PARAMETERS, SIMULATION_CONDITION_ID,
114-
PREEQUILIBRATION_CONDITION_ID]
115216
groupvars = get_notnull_columns(petab_problem.measurement_df,
116-
possible_groupvars)
217+
POSSIBLE_GROUPVARS_FLATTENED_PROBLEM)
218+
219+
mappings = get_flattened_id_mappings(petab_problem)
220+
117221
for groupvar, measurements in \
118222
petab_problem.measurement_df.groupby(groupvars, dropna=False):
119223
obs_id = groupvar[groupvars.index(OBSERVABLE_ID)]
120-
# construct replacement id
121-
replacement_id = ''
122-
for field in possible_groupvars:
123-
if field in groupvars:
124-
val = str(groupvar[groupvars.index(field)])\
125-
.replace(PARAMETER_SEPARATOR, '_').replace('.', '_')
126-
if replacement_id == '':
127-
replacement_id = val
128-
elif val != '':
129-
replacement_id += f'__{val}'
130-
131-
logger.debug(f'Creating synthetic observable {obs_id}')
132-
if replacement_id in petab_problem.observable_df.index:
133-
raise RuntimeError('could not create synthetic observables '
134-
f'since {replacement_id} was already '
135-
'present in observable table')
224+
observable_replacement_id = \
225+
get_observable_replacement_id(groupvars, groupvar)
226+
136227
observable = petab_problem.observable_df.loc[obs_id].copy()
137-
observable.name = replacement_id
138-
for field, parname, target in [
228+
observable.name = observable_replacement_id
229+
for field, hyperparameter_type, target in [
139230
(NOISE_PARAMETERS, 'noiseParameter', NOISE_FORMULA),
140231
(OBSERVABLE_PARAMETERS, 'observableParameter', OBSERVABLE_FORMULA)
141232
]:
142233
if field in measurements:
234+
hyperparameter_replacement_id = \
235+
get_hyperparameter_replacement_id(
236+
hyperparameter_type=hyperparameter_type,
237+
observable_replacement_id=observable_replacement_id,
238+
)
239+
hyperparameter_id = \
240+
mappings[field][hyperparameter_replacement_id]
143241
observable[target] = re.sub(
144-
fr'{parname}([0-9]+)_{obs_id}',
145-
f'{parname}\\1_{replacement_id}',
146-
observable[target]
242+
hyperparameter_id,
243+
hyperparameter_replacement_id,
244+
observable[target],
147245
)
148246

149-
measurements[OBSERVABLE_ID] = replacement_id
247+
measurements[OBSERVABLE_ID] = observable_replacement_id
150248
new_measurement_dfs.append(measurements)
151249
new_observable_dfs.append(observable)
152250

@@ -155,6 +253,37 @@ def flatten_timepoint_specific_output_overrides(
155253
petab_problem.measurement_df = pd.concat(new_measurement_dfs)
156254

157255

256+
def unflatten_simulation_df(
257+
simulation_df: pd.DataFrame,
258+
petab_problem: 'petab.problem.Problem',
259+
) -> None:
260+
"""Unflatten simulations from a flattened PEtab problem.
261+
262+
A flattened PEtab problem is the output of applying
263+
:func:`flatten_timepoint_specific_output_overrides` to a PEtab problem.
264+
265+
Arguments:
266+
simulation_df:
267+
The simulation dataframe. A dataframe in the same format as a PEtab
268+
measurements table, but with the ``measurement`` column switched
269+
with a ``simulation`` column.
270+
petab_problem:
271+
The unflattened PEtab problem.
272+
273+
Returns:
274+
The simulation dataframe for the unflattened PEtab problem.
275+
"""
276+
mappings = get_flattened_id_mappings(petab_problem)
277+
original_observable_ids = (
278+
simulation_df[OBSERVABLE_ID]
279+
.replace(mappings[OBSERVABLE_ID])
280+
)
281+
unflattened_simulation_df = simulation_df.assign(**{
282+
OBSERVABLE_ID: original_observable_ids,
283+
})
284+
return unflattened_simulation_df
285+
286+
158287
def concat_tables(
159288
tables: Union[str, Path, pd.DataFrame,
160289
Iterable[Union[pd.DataFrame, str, Path]]],

tests/test_petab.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,19 @@ def test_flatten_timepoint_specific_output_overrides():
372372

373373
assert petab.lint_problem(problem) is False
374374

375+
simulation_df = copy.deepcopy(problem.measurement_df)
376+
simulation_df.rename(columns={MEASUREMENT: SIMULATION})
377+
unflattened_problem = petab.Problem(
378+
measurement_df=measurement_df,
379+
observable_df=observable_df,
380+
)
381+
unflattened_simulation_df = petab.core.unflatten_simulation_df(
382+
simulation_df=simulation_df,
383+
petab_problem=unflattened_problem,
384+
)
385+
# The unflattened simulation dataframe has the original observable IDs.
386+
assert (unflattened_simulation_df[OBSERVABLE_ID] == 'obs1').all()
387+
375388

376389
def test_flatten_timepoint_specific_output_overrides_special_cases():
377390
"""Test flatten_timepoint_specific_output_overrides

0 commit comments

Comments
 (0)