Skip to content

Commit d78c9c0

Browse files
committed
observables
1 parent ec58463 commit d78c9c0

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

petab/v2/core.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Types around the PEtab object model."""
2+
from __future__ import annotations
3+
4+
from enum import Enum
5+
from pathlib import Path
6+
7+
import numpy as np
8+
import pandas as pd
9+
import sympy as sp
10+
from pydantic import (
11+
BaseModel,
12+
Field,
13+
ValidationInfo,
14+
field_validator,
15+
)
16+
17+
from ..v1.lint import is_valid_identifier
18+
from ..v1.math import sympify_petab
19+
from . import C
20+
21+
22+
class ObservableTransformation(str, Enum):
23+
LIN = C.LIN
24+
LOG = C.LOG
25+
LOG10 = C.LOG10
26+
27+
28+
class NoiseDistribution(str, Enum):
29+
NORMAL = C.NORMAL
30+
LAPLACE = C.LAPLACE
31+
32+
33+
class Observable(BaseModel):
34+
id: str = Field(alias=C.OBSERVABLE_ID)
35+
name: str | None = Field(alias=C.OBSERVABLE_NAME, default=None)
36+
formula: sp.Basic | None = Field(alias=C.OBSERVABLE_FORMULA, default=None)
37+
transformation: ObservableTransformation = Field(
38+
alias=C.OBSERVABLE_TRANSFORMATION, default=ObservableTransformation.LIN
39+
)
40+
noise_formula: sp.Basic | None = Field(alias=C.NOISE_FORMULA, default=None)
41+
noise_distribution: NoiseDistribution = Field(
42+
alias=C.NOISE_DISTRIBUTION, default=NoiseDistribution.NORMAL
43+
)
44+
45+
@field_validator("id")
46+
@classmethod
47+
def validate_id(cls, v):
48+
if not v:
49+
raise ValueError("ID must not be empty.")
50+
if not is_valid_identifier(v):
51+
raise ValueError(f"Invalid ID: {v}")
52+
return v
53+
54+
@field_validator(
55+
"name",
56+
"formula",
57+
"noise_formula",
58+
"noise_formula",
59+
"noise_distribution",
60+
"transformation",
61+
mode="before",
62+
)
63+
@classmethod
64+
def convert_nan_to_none(cls, v, info: ValidationInfo):
65+
if isinstance(v, float) and np.isnan(v):
66+
return cls.model_fields[info.field_name].default
67+
return v
68+
69+
@field_validator("formula", "noise_formula", mode="before")
70+
@classmethod
71+
def sympify(cls, v):
72+
if v is None or isinstance(v, sp.Basic):
73+
return v
74+
if isinstance(v, float) and np.isnan(v):
75+
return None
76+
77+
return sympify_petab(v)
78+
79+
class Config:
80+
populate_by_name = True
81+
arbitrary_types_allowed = True
82+
83+
84+
class ObservablesTable(BaseModel):
85+
observables: list[Observable]
86+
87+
@classmethod
88+
def from_dataframe(cls, df: pd.DataFrame) -> ObservablesTable:
89+
if df is None:
90+
return cls(observables=[])
91+
92+
observables = [
93+
Observable(**row.to_dict())
94+
for _, row in df.reset_index().iterrows()
95+
]
96+
97+
return cls(observables=observables)
98+
99+
def to_dataframe(self) -> pd.DataFrame:
100+
return pd.DataFrame(self.model_dump()["observables"])
101+
102+
@classmethod
103+
def from_tsv(cls, file_path: str | Path) -> ObservablesTable:
104+
df = pd.read_csv(file_path, sep="\t")
105+
return cls.from_dataframe(df)
106+
107+
def to_tsv(self, file_path: str | Path) -> None:
108+
df = self.to_dataframe()
109+
df.to_csv(file_path, sep="\t", index=False)

petab/v2/problem.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def __init__(
9393
] = default_validation_tasks.copy()
9494
self.config = config
9595

96+
from .core import Observable, ObservablesTable
97+
98+
self.observables: list[Observable] = ObservablesTable.from_dataframe(
99+
self.observable_df
100+
)
101+
96102
def __str__(self):
97103
model = f"with model ({self.model})" if self.model else "without model"
98104

tests/v2/test_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
from petab.v2.core import ObservablesTable
5+
6+
7+
def test_observables_table():
8+
file = (
9+
Path(__file__).parents[2]
10+
/ "doc/example/example_Fujita/Fujita_observables.tsv"
11+
)
12+
13+
# read-write-read round trip
14+
observables = ObservablesTable.from_tsv(file)
15+
16+
with tempfile.TemporaryDirectory() as tmp_dir:
17+
tmp_file = Path(tmp_dir) / "observables.tsv"
18+
observables.to_tsv(tmp_file)
19+
observables2 = ObservablesTable.from_tsv(tmp_file)
20+
assert observables == observables2

0 commit comments

Comments
 (0)