1010from warnings import warn
1111
1212import pandas as pd
13+ from pydantic import AnyUrl , BaseModel , Field , RootModel
1314
1415from . import (
1516 conditions ,
@@ -78,6 +79,7 @@ def __init__(
7879 observable_df : pd .DataFrame = None ,
7980 mapping_df : pd .DataFrame = None ,
8081 extensions_config : dict = None ,
82+ config : ProblemConfig = None ,
8183 ):
8284 self .condition_df : pd .DataFrame | None = condition_df
8385 self .measurement_df : pd .DataFrame | None = measurement_df
@@ -112,6 +114,7 @@ def __init__(
112114
113115 self .model : Model | None = model
114116 self .extensions_config = extensions_config or {}
117+ self .config = config
115118
116119 def __getattr__ (self , name ):
117120 # For backward-compatibility, allow access to SBML model related
@@ -261,10 +264,14 @@ def from_yaml(
261264 yaml_config: PEtab configuration as dictionary or YAML file name
262265 base_path: Base directory or URL to resolve relative paths
263266 """
267+ # path to the yaml file
268+ filepath = None
269+
264270 if isinstance (yaml_config , Path ):
265271 yaml_config = str (yaml_config )
266272
267273 if isinstance (yaml_config , str ):
274+ filepath = yaml_config
268275 if base_path is None :
269276 base_path = get_path_prefix (yaml_config )
270277 yaml_config = yaml .load_yaml (yaml_config )
@@ -296,59 +303,58 @@ def get_path(filename):
296303 DeprecationWarning ,
297304 stacklevel = 2 ,
298305 )
306+ config = ProblemConfig (
307+ ** yaml_config , base_path = base_path , filepath = filepath
308+ )
309+ problem0 = config .problems [0 ]
310+ # currently required for handling PEtab v2 in here
311+ problem0_ = yaml_config ["problems" ][0 ]
299312
300- problem0 = yaml_config ["problems" ][0 ]
301-
302- if isinstance (yaml_config [PARAMETER_FILE ], list ):
313+ if isinstance (config .parameter_file , list ):
303314 parameter_df = parameters .get_parameter_df (
304- [get_path (f ) for f in yaml_config [ PARAMETER_FILE ] ]
315+ [get_path (f ) for f in config . parameter_file ]
305316 )
306317 else :
307318 parameter_df = (
308- parameters .get_parameter_df (
309- get_path (yaml_config [PARAMETER_FILE ])
310- )
311- if yaml_config [PARAMETER_FILE ]
319+ parameters .get_parameter_df (get_path (config .parameter_file ))
320+ if config .parameter_file
312321 else None
313322 )
314-
315- if yaml_config [FORMAT_VERSION ] in [1 , "1" , "1.0.0" ]:
316- if len (problem0 [SBML_FILES ]) > 1 :
323+ if config .format_version .root in [1 , "1" , "1.0.0" ]:
324+ if len (problem0 .sbml_files ) > 1 :
317325 # TODO https://github.com/PEtab-dev/libpetab-python/issues/6
318326 raise NotImplementedError (
319327 "Support for multiple models is not yet implemented."
320328 )
321329
322330 model = (
323331 model_factory (
324- get_path (problem0 [ SBML_FILES ] [0 ]),
332+ get_path (problem0 . sbml_files [0 ]),
325333 MODEL_TYPE_SBML ,
326334 model_id = None ,
327335 )
328- if problem0 [ SBML_FILES ]
336+ if problem0 . sbml_files
329337 else None
330338 )
331339 else :
332- if len (problem0 [MODEL_FILES ]) > 1 :
340+ if len (problem0_ [MODEL_FILES ]) > 1 :
333341 # TODO https://github.com/PEtab-dev/libpetab-python/issues/6
334342 raise NotImplementedError (
335343 "Support for multiple models is not yet implemented."
336344 )
337- if not problem0 [MODEL_FILES ]:
345+ if not problem0_ [MODEL_FILES ]:
338346 model = None
339347 else :
340348 model_id , model_info = next (
341- iter (problem0 [MODEL_FILES ].items ())
349+ iter (problem0_ [MODEL_FILES ].items ())
342350 )
343351 model = model_factory (
344352 get_path (model_info [MODEL_LOCATION ]),
345353 model_info [MODEL_LANGUAGE ],
346354 model_id = model_id ,
347355 )
348356
349- measurement_files = [
350- get_path (f ) for f in problem0 .get (MEASUREMENT_FILES , [])
351- ]
357+ measurement_files = [get_path (f ) for f in problem0 .measurement_files ]
352358 # If there are multiple tables, we will merge them
353359 measurement_df = (
354360 core .concat_tables (
@@ -358,9 +364,7 @@ def get_path(filename):
358364 else None
359365 )
360366
361- condition_files = [
362- get_path (f ) for f in problem0 .get (CONDITION_FILES , [])
363- ]
367+ condition_files = [get_path (f ) for f in problem0 .condition_files ]
364368 # If there are multiple tables, we will merge them
365369 condition_df = (
366370 core .concat_tables (condition_files , conditions .get_condition_df )
@@ -369,7 +373,7 @@ def get_path(filename):
369373 )
370374
371375 visualization_files = [
372- get_path (f ) for f in problem0 .get ( VISUALIZATION_FILES , [])
376+ get_path (f ) for f in problem0 .visualization_files
373377 ]
374378 # If there are multiple tables, we will merge them
375379 visualization_df = (
@@ -378,17 +382,15 @@ def get_path(filename):
378382 else None
379383 )
380384
381- observable_files = [
382- get_path (f ) for f in problem0 .get (OBSERVABLE_FILES , [])
383- ]
385+ observable_files = [get_path (f ) for f in problem0 .observable_files ]
384386 # If there are multiple tables, we will merge them
385387 observable_df = (
386388 core .concat_tables (observable_files , observables .get_observable_df )
387389 if observable_files
388390 else None
389391 )
390392
391- mapping_files = [get_path (f ) for f in problem0 .get (MAPPING_FILES , [])]
393+ mapping_files = [get_path (f ) for f in problem0_ .get (MAPPING_FILES , [])]
392394 # If there are multiple tables, we will merge them
393395 mapping_df = (
394396 core .concat_tables (mapping_files , mapping .get_mapping_df )
@@ -405,6 +407,7 @@ def get_path(filename):
405407 visualization_df = visualization_df ,
406408 mapping_df = mapping_df ,
407409 extensions_config = yaml_config .get (EXTENSIONS , {}),
410+ config = config ,
408411 )
409412
410413 @staticmethod
@@ -1005,3 +1008,50 @@ def n_priors(self) -> int:
10051008 return 0
10061009
10071010 return self .parameter_df [OBJECTIVE_PRIOR_PARAMETERS ].notna ().sum ()
1011+
1012+
1013+ class VersionNumber (RootModel ):
1014+ root : str | int
1015+
1016+
1017+ class ListOfFiles (RootModel ):
1018+ """List of files."""
1019+
1020+ root : list [str | AnyUrl ] = Field (..., description = "List of files." )
1021+
1022+ def __iter__ (self ):
1023+ return iter (self .root )
1024+
1025+ def __len__ (self ):
1026+ return len (self .root )
1027+
1028+ def __getitem__ (self , index ):
1029+ return self .root [index ]
1030+
1031+
1032+ class SubProblem (BaseModel ):
1033+ """A `problems` object in the PEtab problem configuration."""
1034+
1035+ sbml_files : ListOfFiles = []
1036+ measurement_files : ListOfFiles = []
1037+ condition_files : ListOfFiles = []
1038+ observable_files : ListOfFiles = []
1039+ visualization_files : ListOfFiles = []
1040+
1041+
1042+ class ProblemConfig (BaseModel ):
1043+ """The PEtab problem configuration."""
1044+
1045+ filepath : str | AnyUrl | None = Field (
1046+ None ,
1047+ description = "The path to the PEtab problem configuration." ,
1048+ exclude = True ,
1049+ )
1050+ base_path : str | AnyUrl | None = Field (
1051+ None ,
1052+ description = "The base path to resolve relative paths." ,
1053+ exclude = True ,
1054+ )
1055+ format_version : VersionNumber = 1
1056+ parameter_file : str | AnyUrl | None = None
1057+ problems : list [SubProblem ] = []
0 commit comments