Skip to content

Commit debae6e

Browse files
committed
v2: Allow applying multiple conditions simultaneously
1 parent 8614c02 commit debae6e

File tree

3 files changed

+52
-35
lines changed

3 files changed

+52
-35
lines changed

petab/v2/core.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def free_symbols(self) -> set[sp.Symbol]:
475475

476476
class ExperimentPeriod(BaseModel):
477477
"""A period of a timecourse or experiment defined by a start time
478-
and a condition ID.
478+
and a list of condition IDs.
479479
480480
This corresponds to a row of the PEtab experiments table.
481481
"""
@@ -484,20 +484,19 @@ class ExperimentPeriod(BaseModel):
484484
time: Annotated[float, AfterValidator(_is_finite_or_neg_inf)] = Field(
485485
alias=C.TIME
486486
)
487-
#: The ID of the condition to be applied at the start time.
488-
condition_id: str | None = Field(alias=C.CONDITION_ID, default=None)
487+
#: The IDs of the conditions to be applied at the start time.
488+
condition_ids: list[str] = []
489489

490490
#: :meta private:
491491
model_config = ConfigDict(populate_by_name=True, extra="allow")
492492

493-
@field_validator("condition_id", mode="before")
493+
@field_validator("condition_ids", mode="before")
494494
@classmethod
495-
def _validate_id(cls, condition_id):
496-
if pd.isna(condition_id) or not condition_id:
497-
return None
498-
if not is_valid_identifier(condition_id):
499-
raise ValueError(f"Invalid ID: {condition_id}")
500-
return condition_id
495+
def _validate_ids(cls, condition_ids):
496+
for condition_id in condition_ids:
497+
if not is_valid_identifier(condition_id):
498+
raise ValueError(f"Invalid ID: {condition_id}")
499+
return condition_ids
501500

502501

503502
class Experiment(BaseModel):
@@ -548,12 +547,20 @@ def from_df(cls, df: pd.DataFrame) -> ExperimentTable:
548547

549548
experiments = []
550549
for experiment_id, cur_exp_df in df.groupby(C.EXPERIMENT_ID):
551-
periods = [
552-
ExperimentPeriod(
553-
time=row[C.TIME], condition_id=row[C.CONDITION_ID]
550+
periods = []
551+
for timepoint in cur_exp_df[C.TIME].unique():
552+
condition_ids = [
553+
cid
554+
for cid in cur_exp_df.loc[
555+
cur_exp_df[C.TIME] == timepoint, C.CONDITION_ID
556+
]
557+
if not pd.isna(cid)
558+
]
559+
periods.append(
560+
ExperimentPeriod(
561+
time=timepoint, condition_ids=condition_ids
562+
)
554563
)
555-
for _, row in cur_exp_df.iterrows()
556-
]
557564
experiments.append(Experiment(id=experiment_id, periods=periods))
558565

559566
return cls(experiments=experiments)
@@ -563,10 +570,12 @@ def to_df(self) -> pd.DataFrame:
563570
records = [
564571
{
565572
C.EXPERIMENT_ID: experiment.id,
566-
**period.model_dump(by_alias=True),
573+
C.TIME: period.time,
574+
C.CONDITION_ID: condition_id,
567575
}
568576
for experiment in self.experiments
569577
for period in experiment.periods
578+
for condition_id in period.condition_ids or [""]
570579
]
571580
return (
572581
pd.DataFrame(records)

petab/v2/lint.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Set
99
from dataclasses import dataclass, field
1010
from enum import IntEnum
11+
from itertools import chain
1112
from pathlib import Path
1213

1314
import pandas as pd
@@ -484,11 +485,14 @@ def run(self, problem: Problem) -> ValidationIssue | None:
484485
c.id for c in problem.condition_table.conditions
485486
}
486487
for experiment in problem.experiment_table.experiments:
487-
missing_conditions = {
488-
period.condition_id
489-
for period in experiment.periods
490-
if period.condition_id is not None
491-
} - available_conditions
488+
missing_conditions = (
489+
set(
490+
chain.from_iterable(
491+
period.condition_ids for period in experiment.periods
492+
)
493+
)
494+
- available_conditions
495+
)
492496
if missing_conditions:
493497
messages.append(
494498
f"Experiment {experiment.id} requires conditions that are "
@@ -499,6 +503,9 @@ def run(self, problem: Problem) -> ValidationIssue | None:
499503
return ValidationError("\n".join(messages))
500504

501505

506+
# TODO: Check that changes of simultaneously applied conditions don't intersect
507+
508+
502509
class CheckAllParametersPresentInParameterTable(ValidationTask):
503510
"""Ensure all required parameters are contained in the parameter table
504511
with no additional ones."""
@@ -646,12 +653,13 @@ class CheckUnusedConditions(ValidationTask):
646653
table."""
647654

648655
def run(self, problem: Problem) -> ValidationIssue | None:
649-
used_conditions = {
650-
p.condition_id
651-
for e in problem.experiment_table.experiments
652-
for p in e.periods
653-
if p.condition_id is not None
654-
}
656+
used_conditions = set(
657+
chain.from_iterable(
658+
p.condition_ids
659+
for e in problem.experiment_table.experiments
660+
for p in e.periods
661+
)
662+
)
655663
available_conditions = {
656664
c.id for c in problem.condition_table.conditions
657665
}

tests/v2/test_core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def test_experiment_add_periods():
3939
exp = Experiment(id="exp1")
4040
assert exp.periods == []
4141

42-
p1 = ExperimentPeriod(time=0, condition_id="p1")
43-
p2 = ExperimentPeriod(time=1, condition_id="p2")
44-
p3 = ExperimentPeriod(time=2, condition_id="p3")
42+
p1 = ExperimentPeriod(time=0, condition_ids=["p1"])
43+
p2 = ExperimentPeriod(time=1, condition_ids=["p2"])
44+
p3 = ExperimentPeriod(time=2, condition_ids=["p3"])
4545
exp += p1
4646
exp += p2
4747

@@ -201,22 +201,22 @@ def test_change():
201201

202202
def test_period():
203203
ExperimentPeriod(time=0)
204-
ExperimentPeriod(time=1, condition_id="p1")
205-
ExperimentPeriod(time="-inf", condition_id="p1")
204+
ExperimentPeriod(time=1, condition_ids=["p1"])
205+
ExperimentPeriod(time="-inf", condition_ids=["p1"])
206206

207207
assert (
208208
ExperimentPeriod(time="1", condition_id="p1", non_petab=1).non_petab
209209
== 1
210210
)
211211

212212
with pytest.raises(ValidationError, match="got inf"):
213-
ExperimentPeriod(time="inf", condition_id="p1")
213+
ExperimentPeriod(time="inf", condition_ids=["p1"])
214214

215215
with pytest.raises(ValidationError, match="Invalid ID"):
216-
ExperimentPeriod(time=1, condition_id="1_condition")
216+
ExperimentPeriod(time=1, condition_ids=["1_condition"])
217217

218218
with pytest.raises(ValidationError, match="type=missing"):
219-
ExperimentPeriod(condition_id="condition")
219+
ExperimentPeriod(condition_ids=["condition"])
220220

221221

222222
def test_parameter():

0 commit comments

Comments
 (0)