2323def exists (v ):
2424 return v is not None
2525
26+ @typecheck
27+ def safe_deep_get (
28+ d : dict ,
29+ dotpath : str | List [str ], # dotpath notation, so accessing {'a': {'b'': {'c': 1}}} would be "a.b.c"
30+ default = None
31+ ):
32+ if isinstance (dotpath ,str ):
33+ dotpath = dotpath .split ('.' )
34+
35+ for key in dotpath :
36+ if key not in d :
37+ return default
38+
39+ d = d [key ]
40+
41+ return d
42+
2643@typecheck
2744def yaml_config_path_to_dict (
2845 path : str | Path
29- ) -> dict | None :
46+ ) -> dict :
3047
3148 if isinstance (path , str ):
3249 path = Path (path )
@@ -73,16 +90,26 @@ class Alphafold3Config(BaseModelWithExtra):
7390
7491 @staticmethod
7592 @typecheck
76- def from_yaml_file (path : str | Path ):
93+ def from_yaml_file (
94+ path : str | Path ,
95+ dotpath : str | List [str ] = []
96+ ):
7797 config_dict = yaml_config_path_to_dict (path )
98+ config_dict = safe_deep_get (config_dict , dotpath )
99+ assert exists (config_dict ), f'config not found at path { "." .join (dotpath )} '
100+
78101 return Alphafold3Config (** config_dict )
79102
80103 def create_instance (self ) -> Alphafold3 :
81104 alphafold3 = Alphafold3 (** self .model_dump ())
82105 return alphafold3
83106
84- def create_instance_from_yaml_file (path : str | Path ) -> Alphafold3 :
85- af3_config = Alphafold3Config .from_yaml_file (path )
107+ def create_instance_from_yaml_file (
108+ path : str | Path ,
109+ dotpath : str | List [str ] = []
110+ ) -> Alphafold3 :
111+
112+ af3_config = Alphafold3Config .from_yaml_file (path , dotpath )
86113 return af3_config .create_instance ()
87114
88115class TrainerConfig (BaseModelWithExtra ):
@@ -102,8 +129,14 @@ class TrainerConfig(BaseModelWithExtra):
102129
103130 @staticmethod
104131 @typecheck
105- def from_yaml_file (path : str | Path ):
132+ def from_yaml_file (
133+ path : str | Path ,
134+ dotpath : str | List [str ] = []
135+ ):
106136 config_dict = yaml_config_path_to_dict (path )
137+ config_dict = safe_deep_get (config_dict , dotpath )
138+ assert exists (config_dict ), f'config not found at path { "." .join (dotpath )} '
139+
107140 return TrainerConfig (** config_dict )
108141
109142 def create_instance (
@@ -137,10 +170,11 @@ def create_instance(
137170
138171 def create_instance_from_yaml_file (
139172 path : str | Path ,
173+ dotpath : str | List [str ] = [],
140174 ** kwargs
141175 ) -> Trainer :
142176
143- trainer_config = TrainerConfig .from_yaml_file (path )
177+ trainer_config = TrainerConfig .from_yaml_file (path , dotpath )
144178 return trainer_config .create_instance (** kwargs )
145179
146180# convenience functions
0 commit comments