Skip to content

Commit 5b87701

Browse files
committed
dfs->objs
1 parent 31da444 commit 5b87701

File tree

5 files changed

+269
-114
lines changed

5 files changed

+269
-114
lines changed

petab/v2/core.py

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"ObservableTransformation",
2828
"NoiseDistribution",
2929
"Change",
30-
"ChangeSet",
30+
"Condition",
3131
"ConditionsTable",
3232
"OperationType",
3333
"ExperimentPeriod",
@@ -200,7 +200,25 @@ def from_df(cls, df: pd.DataFrame) -> ObservablesTable:
200200

201201
def to_df(self) -> pd.DataFrame:
202202
"""Convert the ObservablesTable to a DataFrame."""
203-
return pd.DataFrame(self.model_dump()["observables"])
203+
records = self.model_dump(by_alias=True)["observables"]
204+
for record in records:
205+
obs = record[C.OBSERVABLE_FORMULA]
206+
noise = record[C.NOISE_FORMULA]
207+
record[C.OBSERVABLE_FORMULA] = (
208+
None
209+
if obs is None
210+
else str(obs)
211+
if not obs.is_number
212+
else float(obs)
213+
)
214+
record[C.NOISE_FORMULA] = (
215+
None
216+
if noise is None
217+
else str(noise)
218+
if not noise.is_number
219+
else float(noise)
220+
)
221+
return pd.DataFrame(records).set_index([C.OBSERVABLE_ID])
204222

205223
@classmethod
206224
def from_tsv(cls, file_path: str | Path) -> ObservablesTable:
@@ -211,7 +229,7 @@ def from_tsv(cls, file_path: str | Path) -> ObservablesTable:
211229
def to_tsv(self, file_path: str | Path) -> None:
212230
"""Write the ObservablesTable to a TSV file."""
213231
df = self.to_df()
214-
df.to_csv(file_path, sep="\t", index=False)
232+
df.to_csv(file_path, sep="\t", index=True)
215233

216234
def __add__(self, other: Observable) -> ObservablesTable:
217235
"""Add an observable to the table."""
@@ -290,14 +308,14 @@ def _sympify(cls, v):
290308
return sympify_petab(v)
291309

292310

293-
class ChangeSet(BaseModel):
311+
class Condition(BaseModel):
294312
"""A set of changes to the model or model state.
295313
296314
A set of simultaneously occurring changes to the model or model state,
297315
corresponding to a perturbation of the underlying system. This corresponds
298316
to all rows of the PEtab conditions table with the same condition ID.
299317
300-
>>> ChangeSet(
318+
>>> Condition(
301319
... id="condition1",
302320
... changes=[
303321
... Change(
@@ -307,7 +325,7 @@ class ChangeSet(BaseModel):
307325
... )
308326
... ],
309327
... ) # doctest: +NORMALIZE_WHITESPACE
310-
ChangeSet(id='condition1', changes=[Change(target_id='k1',
328+
Condition(id='condition1', changes=[Change(target_id='k1',
311329
operation_type='setCurrentValue', target_value=10.0000000000000)])
312330
"""
313331

@@ -328,13 +346,13 @@ def _validate_id(cls, v):
328346
raise ValueError(f"Invalid ID: {v}")
329347
return v
330348

331-
def __add__(self, other: Change) -> ChangeSet:
349+
def __add__(self, other: Change) -> Condition:
332350
"""Add a change to the set."""
333351
if not isinstance(other, Change):
334352
raise TypeError("Can only add Change to ChangeSet")
335-
return ChangeSet(id=self.id, changes=self.changes + [other])
353+
return Condition(id=self.id, changes=self.changes + [other])
336354

337-
def __iadd__(self, other: Change) -> ChangeSet:
355+
def __iadd__(self, other: Change) -> Condition:
338356
"""Add a change to the set in place."""
339357
if not isinstance(other, Change):
340358
raise TypeError("Can only add Change to ChangeSet")
@@ -346,9 +364,9 @@ class ConditionsTable(BaseModel):
346364
"""PEtab conditions table."""
347365

348366
#: List of conditions.
349-
conditions: list[ChangeSet] = []
367+
conditions: list[Condition] = []
350368

351-
def __getitem__(self, condition_id: str) -> ChangeSet:
369+
def __getitem__(self, condition_id: str) -> Condition:
352370
"""Get a condition by ID."""
353371
for condition in self.conditions:
354372
if condition.id == condition_id:
@@ -364,18 +382,28 @@ def from_df(cls, df: pd.DataFrame) -> ConditionsTable:
364382
conditions = []
365383
for condition_id, sub_df in df.groupby(C.CONDITION_ID):
366384
changes = [Change(**row.to_dict()) for _, row in sub_df.iterrows()]
367-
conditions.append(ChangeSet(id=condition_id, changes=changes))
385+
conditions.append(Condition(id=condition_id, changes=changes))
368386

369387
return cls(conditions=conditions)
370388

371389
def to_df(self) -> pd.DataFrame:
372390
"""Convert the ConditionsTable to a DataFrame."""
373391
records = [
374-
{C.CONDITION_ID: condition.id, **change.model_dump()}
392+
{C.CONDITION_ID: condition.id, **change.model_dump(by_alias=True)}
375393
for condition in self.conditions
376394
for change in condition.changes
377395
]
378-
return pd.DataFrame(records)
396+
for record in records:
397+
record[C.TARGET_VALUE] = (
398+
float(record[C.TARGET_VALUE])
399+
if record[C.TARGET_VALUE].is_number
400+
else str(record[C.TARGET_VALUE])
401+
)
402+
return (
403+
pd.DataFrame(records)
404+
if records
405+
else pd.DataFrame(columns=C.CONDITION_DF_REQUIRED_COLS)
406+
)
379407

380408
@classmethod
381409
def from_tsv(cls, file_path: str | Path) -> ConditionsTable:
@@ -388,15 +416,15 @@ def to_tsv(self, file_path: str | Path) -> None:
388416
df = self.to_df()
389417
df.to_csv(file_path, sep="\t", index=False)
390418

391-
def __add__(self, other: ChangeSet) -> ConditionsTable:
419+
def __add__(self, other: Condition) -> ConditionsTable:
392420
"""Add a condition to the table."""
393-
if not isinstance(other, ChangeSet):
421+
if not isinstance(other, Condition):
394422
raise TypeError("Can only add ChangeSet to ConditionsTable")
395423
return ConditionsTable(conditions=self.conditions + [other])
396424

397-
def __iadd__(self, other: ChangeSet) -> ConditionsTable:
425+
def __iadd__(self, other: Condition) -> ConditionsTable:
398426
"""Add a condition to the table in place."""
399-
if not isinstance(other, ChangeSet):
427+
if not isinstance(other, Condition):
400428
raise TypeError("Can only add ChangeSet to ConditionsTable")
401429
self.conditions.append(other)
402430
return self
@@ -498,7 +526,19 @@ def from_df(cls, df: pd.DataFrame) -> ExperimentsTable:
498526

499527
def to_df(self) -> pd.DataFrame:
500528
"""Convert the ExperimentsTable to a DataFrame."""
501-
return pd.DataFrame(self.model_dump()["experiments"])
529+
records = [
530+
{
531+
C.EXPERIMENT_ID: experiment.id,
532+
**period.model_dump(by_alias=True),
533+
}
534+
for experiment in self.experiments
535+
for period in experiment.periods
536+
]
537+
return (
538+
pd.DataFrame(records)
539+
if records
540+
else pd.DataFrame(columns=C.EXPERIMENT_DF_REQUIRED_COLS)
541+
)
502542

503543
@classmethod
504544
def from_tsv(cls, file_path: str | Path) -> ExperimentsTable:
@@ -617,7 +657,16 @@ def from_df(
617657

618658
def to_df(self) -> pd.DataFrame:
619659
"""Convert the MeasurementTable to a DataFrame."""
620-
return pd.DataFrame(self.model_dump()["measurements"])
660+
records = self.model_dump(by_alias=True)["measurements"]
661+
for record in records:
662+
record[C.OBSERVABLE_PARAMETERS] = C.PARAMETER_SEPARATOR.join(
663+
map(str, record[C.OBSERVABLE_PARAMETERS])
664+
)
665+
record[C.NOISE_PARAMETERS] = C.PARAMETER_SEPARATOR.join(
666+
map(str, record[C.NOISE_PARAMETERS])
667+
)
668+
669+
return pd.DataFrame(records)
621670

622671
@classmethod
623672
def from_tsv(cls, file_path: str | Path) -> MeasurementTable:
@@ -687,7 +736,12 @@ def from_df(cls, df: pd.DataFrame) -> MappingTable:
687736

688737
def to_df(self) -> pd.DataFrame:
689738
"""Convert the MappingTable to a DataFrame."""
690-
return pd.DataFrame(self.model_dump()["mappings"])
739+
res = (
740+
pd.DataFrame(self.model_dump(by_alias=True)["mappings"])
741+
if self.mappings
742+
else pd.DataFrame(columns=C.MAPPING_DF_REQUIRED_COLS)
743+
)
744+
return res.set_index([C.PETAB_ENTITY_ID])
691745

692746
@classmethod
693747
def from_tsv(cls, file_path: str | Path) -> MappingTable:
@@ -778,7 +832,9 @@ def from_df(cls, df: pd.DataFrame) -> ParameterTable:
778832

779833
def to_df(self) -> pd.DataFrame:
780834
"""Convert the ParameterTable to a DataFrame."""
781-
return pd.DataFrame(self.model_dump()["parameters"])
835+
return pd.DataFrame(
836+
self.model_dump(by_alias=True)["parameters"]
837+
).set_index([C.PARAMETER_ID])
782838

783839
@classmethod
784840
def from_tsv(cls, file_path: str | Path) -> ParameterTable:

0 commit comments

Comments
 (0)