Skip to content

Commit d1f692f

Browse files
committed
v2: Allow applying multiple conditions simultaneously
1 parent 4238588 commit d1f692f

File tree

5 files changed

+98
-39
lines changed

5 files changed

+98
-39
lines changed

petab/v2/core.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def free_symbols(self) -> set[sp.Symbol]:
475475

476476
class ExperimentPeriod(BaseModel):
477477
"""A period of a timecourse or experiment defined by a start time
478-
and a condition ID.
478+
and a list of condition IDs.
479479
480480
This corresponds to a row of the PEtab experiments table.
481481
"""
@@ -484,20 +484,19 @@ class ExperimentPeriod(BaseModel):
484484
time: Annotated[float, AfterValidator(_is_finite_or_neg_inf)] = Field(
485485
alias=C.TIME
486486
)
487-
#: The ID of the condition to be applied at the start time.
488-
condition_id: str | None = Field(alias=C.CONDITION_ID, default=None)
487+
#: The IDs of the conditions to be applied at the start time.
488+
condition_ids: list[str] = []
489489

490490
#: :meta private:
491491
model_config = ConfigDict(populate_by_name=True, extra="allow")
492492

493-
@field_validator("condition_id", mode="before")
493+
@field_validator("condition_ids", mode="before")
494494
@classmethod
495-
def _validate_id(cls, condition_id):
496-
if pd.isna(condition_id) or not condition_id:
497-
return None
498-
if not is_valid_identifier(condition_id):
499-
raise ValueError(f"Invalid ID: {condition_id}")
500-
return condition_id
495+
def _validate_ids(cls, condition_ids):
496+
for condition_id in condition_ids:
497+
if not is_valid_identifier(condition_id):
498+
raise ValueError(f"Invalid ID: {condition_id}")
499+
return condition_ids
501500

502501

503502
class Experiment(BaseModel):
@@ -548,12 +547,20 @@ def from_df(cls, df: pd.DataFrame) -> ExperimentTable:
548547

549548
experiments = []
550549
for experiment_id, cur_exp_df in df.groupby(C.EXPERIMENT_ID):
551-
periods = [
552-
ExperimentPeriod(
553-
time=row[C.TIME], condition_id=row[C.CONDITION_ID]
550+
periods = []
551+
for timepoint in cur_exp_df[C.TIME].unique():
552+
condition_ids = [
553+
cid
554+
for cid in cur_exp_df.loc[
555+
cur_exp_df[C.TIME] == timepoint, C.CONDITION_ID
556+
]
557+
if not pd.isna(cid)
558+
]
559+
periods.append(
560+
ExperimentPeriod(
561+
time=timepoint, condition_ids=condition_ids
562+
)
554563
)
555-
for _, row in cur_exp_df.iterrows()
556-
]
557564
experiments.append(Experiment(id=experiment_id, periods=periods))
558565

559566
return cls(experiments=experiments)
@@ -563,10 +570,12 @@ def to_df(self) -> pd.DataFrame:
563570
records = [
564571
{
565572
C.EXPERIMENT_ID: experiment.id,
566-
**period.model_dump(by_alias=True),
573+
C.TIME: period.time,
574+
C.CONDITION_ID: condition_id,
567575
}
568576
for experiment in self.experiments
569577
for period in experiment.periods
578+
for condition_id in period.condition_ids or [""]
570579
]
571580
return (
572581
pd.DataFrame(records)

petab/v2/lint.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Set
99
from dataclasses import dataclass, field
1010
from enum import IntEnum
11+
from itertools import chain
1112
from pathlib import Path
1213

1314
import pandas as pd
@@ -373,8 +374,10 @@ class CheckValidConditionTargets(ValidationTask):
373374
"""Check that all condition table targets are valid."""
374375

375376
def run(self, problem: Problem) -> ValidationIssue | None:
376-
allowed_targets = set(
377-
problem.model.get_valid_ids_for_condition_table()
377+
allowed_targets = (
378+
set(problem.model.get_valid_ids_for_condition_table())
379+
if problem.model
380+
else set()
378381
)
379382
allowed_targets |= set(get_output_parameters(problem))
380383
allowed_targets |= {
@@ -394,6 +397,28 @@ def run(self, problem: Problem) -> ValidationIssue | None:
394397
f"Condition table contains invalid targets: {invalid}"
395398
)
396399

400+
# Check that changes of simultaneously applied conditions don't
401+
# intersect
402+
for experiment in problem.experiment_table.experiments:
403+
for period in experiment.periods:
404+
if not period.condition_ids:
405+
continue
406+
period_targets = set()
407+
for condition_id in period.condition_ids:
408+
condition_targets = {
409+
change.target_id
410+
for cond in problem.condition_table.conditions
411+
if cond.id == condition_id
412+
for change in cond.changes
413+
}
414+
if invalid := (period_targets & condition_targets):
415+
return ValidationError(
416+
"Simultaneously applied conditions for experiment "
417+
f"{experiment.id} have overlapping targets "
418+
f"{invalid} at time {period.time}."
419+
)
420+
period_targets |= condition_targets
421+
397422

398423
class CheckUniquePrimaryKeys(ValidationTask):
399424
"""Check that all primary keys are unique."""
@@ -484,11 +509,14 @@ def run(self, problem: Problem) -> ValidationIssue | None:
484509
c.id for c in problem.condition_table.conditions
485510
}
486511
for experiment in problem.experiment_table.experiments:
487-
missing_conditions = {
488-
period.condition_id
489-
for period in experiment.periods
490-
if period.condition_id is not None
491-
} - available_conditions
512+
missing_conditions = (
513+
set(
514+
chain.from_iterable(
515+
period.condition_ids for period in experiment.periods
516+
)
517+
)
518+
- available_conditions
519+
)
492520
if missing_conditions:
493521
messages.append(
494522
f"Experiment {experiment.id} requires conditions that are "
@@ -646,12 +674,13 @@ class CheckUnusedConditions(ValidationTask):
646674
table."""
647675

648676
def run(self, problem: Problem) -> ValidationIssue | None:
649-
used_conditions = {
650-
p.condition_id
651-
for e in problem.experiment_table.experiments
652-
for p in e.periods
653-
if p.condition_id is not None
654-
}
677+
used_conditions = set(
678+
chain.from_iterable(
679+
p.condition_ids
680+
for e in problem.experiment_table.experiments
681+
for p in e.periods
682+
)
683+
)
655684
available_conditions = {
656685
c.id for c in problem.condition_table.conditions
657686
}

petab/v2/problem.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,12 @@ def add_experiment(self, id_: str, *args):
10621062
)
10631063

10641064
periods = [
1065-
core.ExperimentPeriod(time=args[i], condition_id=args[i + 1])
1065+
core.ExperimentPeriod(
1066+
time=args[i],
1067+
condition_ids=[cond]
1068+
if isinstance((cond := args[i + 1]), str)
1069+
else cond,
1070+
)
10661071
for i in range(0, len(args), 2)
10671072
]
10681073

tests/v2/test_core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def test_experiment_add_periods():
3939
exp = Experiment(id="exp1")
4040
assert exp.periods == []
4141

42-
p1 = ExperimentPeriod(time=0, condition_id="p1")
43-
p2 = ExperimentPeriod(time=1, condition_id="p2")
44-
p3 = ExperimentPeriod(time=2, condition_id="p3")
42+
p1 = ExperimentPeriod(time=0, condition_ids=["p1"])
43+
p2 = ExperimentPeriod(time=1, condition_ids=["p2"])
44+
p3 = ExperimentPeriod(time=2, condition_ids=["p3"])
4545
exp += p1
4646
exp += p2
4747

@@ -201,22 +201,22 @@ def test_change():
201201

202202
def test_period():
203203
ExperimentPeriod(time=0)
204-
ExperimentPeriod(time=1, condition_id="p1")
205-
ExperimentPeriod(time="-inf", condition_id="p1")
204+
ExperimentPeriod(time=1, condition_ids=["p1"])
205+
ExperimentPeriod(time="-inf", condition_ids=["p1"])
206206

207207
assert (
208208
ExperimentPeriod(time="1", condition_id="p1", non_petab=1).non_petab
209209
== 1
210210
)
211211

212212
with pytest.raises(ValidationError, match="got inf"):
213-
ExperimentPeriod(time="inf", condition_id="p1")
213+
ExperimentPeriod(time="inf", condition_ids=["p1"])
214214

215215
with pytest.raises(ValidationError, match="Invalid ID"):
216-
ExperimentPeriod(time=1, condition_id="1_condition")
216+
ExperimentPeriod(time=1, condition_ids=["1_condition"])
217217

218218
with pytest.raises(ValidationError, match="type=missing"):
219-
ExperimentPeriod(condition_id="condition")
219+
ExperimentPeriod(condition_ids=["condition"])
220220

221221

222222
def test_parameter():

tests/v2/test_lint.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from copy import deepcopy
44

55
from petab.v2 import Problem
6-
from petab.v2.C import *
76
from petab.v2.lint import *
7+
from petab.v2.models.sbml_model import SbmlModel
88

99

1010
def test_check_experiments():
@@ -21,3 +21,19 @@ def test_check_experiments():
2121
tmp_problem = deepcopy(problem)
2222
tmp_problem["e1"].periods[0].time = tmp_problem["e1"].periods[1].time
2323
assert check.run(tmp_problem) is not None
24+
25+
26+
def test_check_incompatible_targets():
27+
"""Multiple conditions with overlapping targets cannot be applied
28+
at the same time."""
29+
problem = Problem()
30+
problem.model = SbmlModel.from_antimony("p1 = 1; p2 = 2")
31+
problem.add_experiment("e1", 0, "c1", 1, "c2")
32+
problem.add_condition("c1", p1="1")
33+
problem.add_condition("c2", p1="2", p2="2")
34+
check = CheckValidConditionTargets()
35+
assert check.run(problem) is None
36+
37+
problem["e1"].periods[0].condition_ids.append("c2")
38+
assert (error := check.run(problem)) is not None
39+
assert "overlapping targets {'p1'}" in error.message

0 commit comments

Comments
 (0)