Skip to content

Commit 4a551a7

Browse files
authored
Store problem configuration in Problem (#326)
Introduces Problem.config which contains the info from the PEtab yaml file. Sometimes it is convenient to have the original filenames around. Pydantic gives more helpful error messages than `jsonschema` in case of incorrect inputs. Later on, this could replace `jsonschema` completely. Closes #324.
1 parent 45a3371 commit 4a551a7

File tree

2 files changed

+78
-27
lines changed

2 files changed

+78
-27
lines changed

petab/v1/problem.py

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from warnings import warn
1212

1313
import pandas as pd
14+
from pydantic import AnyUrl, BaseModel, Field, RootModel
1415

1516
from . import (
1617
conditions,
@@ -79,6 +80,7 @@ def __init__(
7980
observable_df: pd.DataFrame = None,
8081
mapping_df: pd.DataFrame = None,
8182
extensions_config: dict = None,
83+
config: ProblemConfig = None,
8284
):
8385
self.condition_df: pd.DataFrame | None = condition_df
8486
self.measurement_df: pd.DataFrame | None = measurement_df
@@ -113,6 +115,7 @@ def __init__(
113115

114116
self.model: Model | None = model
115117
self.extensions_config = extensions_config or {}
118+
self.config = config
116119

117120
def __getattr__(self, name):
118121
# For backward-compatibility, allow access to SBML model related
@@ -262,10 +265,14 @@ def from_yaml(
262265
yaml_config: PEtab configuration as dictionary or YAML file name
263266
base_path: Base directory or URL to resolve relative paths
264267
"""
268+
# path to the yaml file
269+
filepath = None
270+
265271
if isinstance(yaml_config, Path):
266272
yaml_config = str(yaml_config)
267273

268274
if isinstance(yaml_config, str):
275+
filepath = yaml_config
269276
if base_path is None:
270277
base_path = get_path_prefix(yaml_config)
271278
yaml_config = yaml.load_yaml(yaml_config)
@@ -297,59 +304,58 @@ def get_path(filename):
297304
DeprecationWarning,
298305
stacklevel=2,
299306
)
307+
config = ProblemConfig(
308+
**yaml_config, base_path=base_path, filepath=filepath
309+
)
310+
problem0 = config.problems[0]
311+
# currently required for handling PEtab v2 in here
312+
problem0_ = yaml_config["problems"][0]
300313

301-
problem0 = yaml_config["problems"][0]
302-
303-
if isinstance(yaml_config[PARAMETER_FILE], list):
314+
if isinstance(config.parameter_file, list):
304315
parameter_df = parameters.get_parameter_df(
305-
[get_path(f) for f in yaml_config[PARAMETER_FILE]]
316+
[get_path(f) for f in config.parameter_file]
306317
)
307318
else:
308319
parameter_df = (
309-
parameters.get_parameter_df(
310-
get_path(yaml_config[PARAMETER_FILE])
311-
)
312-
if yaml_config[PARAMETER_FILE]
320+
parameters.get_parameter_df(get_path(config.parameter_file))
321+
if config.parameter_file
313322
else None
314323
)
315-
316-
if yaml_config[FORMAT_VERSION] in [1, "1", "1.0.0"]:
317-
if len(problem0[SBML_FILES]) > 1:
324+
if config.format_version.root in [1, "1", "1.0.0"]:
325+
if len(problem0.sbml_files) > 1:
318326
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
319327
raise NotImplementedError(
320328
"Support for multiple models is not yet implemented."
321329
)
322330

323331
model = (
324332
model_factory(
325-
get_path(problem0[SBML_FILES][0]),
333+
get_path(problem0.sbml_files[0]),
326334
MODEL_TYPE_SBML,
327335
model_id=None,
328336
)
329-
if problem0[SBML_FILES]
337+
if problem0.sbml_files
330338
else None
331339
)
332340
else:
333-
if len(problem0[MODEL_FILES]) > 1:
341+
if len(problem0_[MODEL_FILES]) > 1:
334342
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
335343
raise NotImplementedError(
336344
"Support for multiple models is not yet implemented."
337345
)
338-
if not problem0[MODEL_FILES]:
346+
if not problem0_[MODEL_FILES]:
339347
model = None
340348
else:
341349
model_id, model_info = next(
342-
iter(problem0[MODEL_FILES].items())
350+
iter(problem0_[MODEL_FILES].items())
343351
)
344352
model = model_factory(
345353
get_path(model_info[MODEL_LOCATION]),
346354
model_info[MODEL_LANGUAGE],
347355
model_id=model_id,
348356
)
349357

350-
measurement_files = [
351-
get_path(f) for f in problem0.get(MEASUREMENT_FILES, [])
352-
]
358+
measurement_files = [get_path(f) for f in problem0.measurement_files]
353359
# If there are multiple tables, we will merge them
354360
measurement_df = (
355361
core.concat_tables(
@@ -359,9 +365,7 @@ def get_path(filename):
359365
else None
360366
)
361367

362-
condition_files = [
363-
get_path(f) for f in problem0.get(CONDITION_FILES, [])
364-
]
368+
condition_files = [get_path(f) for f in problem0.condition_files]
365369
# If there are multiple tables, we will merge them
366370
condition_df = (
367371
core.concat_tables(condition_files, conditions.get_condition_df)
@@ -370,7 +374,7 @@ def get_path(filename):
370374
)
371375

372376
visualization_files = [
373-
get_path(f) for f in problem0.get(VISUALIZATION_FILES, [])
377+
get_path(f) for f in problem0.visualization_files
374378
]
375379
# If there are multiple tables, we will merge them
376380
visualization_df = (
@@ -379,17 +383,15 @@ def get_path(filename):
379383
else None
380384
)
381385

382-
observable_files = [
383-
get_path(f) for f in problem0.get(OBSERVABLE_FILES, [])
384-
]
386+
observable_files = [get_path(f) for f in problem0.observable_files]
385387
# If there are multiple tables, we will merge them
386388
observable_df = (
387389
core.concat_tables(observable_files, observables.get_observable_df)
388390
if observable_files
389391
else None
390392
)
391393

392-
mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])]
394+
mapping_files = [get_path(f) for f in problem0_.get(MAPPING_FILES, [])]
393395
# If there are multiple tables, we will merge them
394396
mapping_df = (
395397
core.concat_tables(mapping_files, mapping.get_mapping_df)
@@ -406,6 +408,7 @@ def get_path(filename):
406408
visualization_df=visualization_df,
407409
mapping_df=mapping_df,
408410
extensions_config=yaml_config.get(EXTENSIONS, {}),
411+
config=config,
409412
)
410413

411414
@staticmethod
@@ -1184,3 +1187,50 @@ def add_measurement(
11841187
if self.measurement_df is not None
11851188
else tmp_df
11861189
)
1190+
1191+
1192+
class VersionNumber(RootModel):
    # Root-model wrapper for a format version; the wrapped value is reached
    # through ``.root``.
    # Both textual and integer spellings are accepted as-is.
    root: str | int
1194+
1195+
1196+
class ListOfFiles(RootModel):
    """List of files."""

    # The wrapped list; each entry is either a plain string or a URL.
    root: list[str | AnyUrl] = Field(..., description="List of files.")

    # The dunder methods below forward to the wrapped list, so instances
    # can be indexed, measured with ``len()`` and iterated directly without
    # spelling out ``.root`` at every call site.

    def __getitem__(self, index):
        # Indexing is delegated unchanged to the wrapped list.
        return self.root[index]

    def __len__(self):
        # Number of entries in the wrapped list.
        return len(self.root)

    def __iter__(self):
        # Iterate over the entries of the wrapped list.
        return iter(self.root)
1209+
1210+
1211+
def _empty_file_list() -> ListOfFiles:
    """Create a fresh, empty ``ListOfFiles`` for use as a field default."""
    return ListOfFiles(root=[])


class SubProblem(BaseModel):
    """A `problems` object in the PEtab problem configuration."""

    # Every field defaults to an *empty* ``ListOfFiles`` built by a factory.
    # A plain ``[]`` default would not be validated by pydantic and would
    # therefore leave a raw ``list`` in a field annotated as ``ListOfFiles``
    # whenever the key is omitted. With the factory, omitted and supplied
    # keys always yield the same type. An empty ``ListOfFiles`` is still
    # falsy (via ``__len__``), so ``if sub.sbml_files`` keeps working.
    sbml_files: ListOfFiles = Field(default_factory=_empty_file_list)
    measurement_files: ListOfFiles = Field(default_factory=_empty_file_list)
    condition_files: ListOfFiles = Field(default_factory=_empty_file_list)
    observable_files: ListOfFiles = Field(default_factory=_empty_file_list)
    visualization_files: ListOfFiles = Field(
        default_factory=_empty_file_list
    )
1219+
1220+
1221+
class ProblemConfig(BaseModel):
    """The PEtab problem configuration."""

    # Location the configuration was loaded from, if any. Excluded from
    # serialization because it is not part of the configuration content.
    filepath: str | AnyUrl | None = Field(
        None,
        description="The path to the PEtab problem configuration.",
        exclude=True,
    )
    # Base used to resolve relative file names. Excluded from serialization.
    # ``Path`` is accepted in addition to strings and URLs because
    # ``Problem.from_yaml`` forwards its ``base_path`` argument unchanged.
    # NOTE(review): assumes callers may pass a ``pathlib.Path`` here, as they
    # can for ``yaml_config`` -- confirm against the ``from_yaml`` signature.
    base_path: str | AnyUrl | Path | None = Field(
        None,
        description="The base path to resolve relative paths.",
        exclude=True,
    )
    # Built by a factory so that an omitted ``format_version`` is a real
    # ``VersionNumber``. A bare ``= 1`` default is not validated by pydantic
    # and would make ``config.format_version.root`` raise AttributeError.
    format_version: VersionNumber = Field(
        default_factory=lambda: VersionNumber(root=1)
    )
    # A single parameter file, a list of parameter files, or nothing.
    # ``Problem.from_yaml`` explicitly handles the list case, so lists must
    # pass validation.
    parameter_file: str | AnyUrl | list[str | AnyUrl] | None = None
    # The sub-problems; a new empty list is created per instance.
    problems: list[SubProblem] = Field(default_factory=list)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"pyyaml",
2323
"jsonschema",
2424
"antlr4-python3-runtime==4.13.1",
25+
"pydantic>=2.10",
2526
]
2627
license = {text = "MIT License"}
2728
authors = [

0 commit comments

Comments
 (0)