Skip to content

Commit b91b088

Browse files
committed
able to instantiate a config from a deeply nested key in the config yml
1 parent f66be08 commit b91b088

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

alphafold3_pytorch/configs.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,27 @@
2323
def exists(v):
2424
return v is not None
2525

26+
@typecheck
27+
def safe_deep_get(
28+
d: dict,
29+
dotpath: str | List[str], # dotpath notation, so accessing {'a': {'b'': {'c': 1}}} would be "a.b.c"
30+
default = None
31+
):
32+
if isinstance(dotpath ,str):
33+
dotpath = dotpath.split('.')
34+
35+
for key in dotpath:
36+
if key not in d:
37+
return default
38+
39+
d = d[key]
40+
41+
return d
42+
2643
@typecheck
2744
def yaml_config_path_to_dict(
2845
path: str | Path
29-
) -> dict | None:
46+
) -> dict:
3047

3148
if isinstance(path, str):
3249
path = Path(path)
@@ -73,16 +90,26 @@ class Alphafold3Config(BaseModelWithExtra):
7390

7491
@staticmethod
7592
@typecheck
76-
def from_yaml_file(path: str | Path):
93+
def from_yaml_file(
94+
path: str | Path,
95+
dotpath: str | List[str] = []
96+
):
7797
config_dict = yaml_config_path_to_dict(path)
98+
config_dict = safe_deep_get(config_dict, dotpath)
99+
assert exists(config_dict), f'config not found at path {".".join(dotpath)}'
100+
78101
return Alphafold3Config(**config_dict)
79102

80103
def create_instance(self) -> Alphafold3:
81104
alphafold3 = Alphafold3(**self.model_dump())
82105
return alphafold3
83106

84-
def create_instance_from_yaml_file(path: str | Path) -> Alphafold3:
85-
af3_config = Alphafold3Config.from_yaml_file(path)
107+
def create_instance_from_yaml_file(
108+
path: str | Path,
109+
dotpath: str | List[str] = []
110+
) -> Alphafold3:
111+
112+
af3_config = Alphafold3Config.from_yaml_file(path, dotpath)
86113
return af3_config.create_instance()
87114

88115
class TrainerConfig(BaseModelWithExtra):
@@ -102,8 +129,14 @@ class TrainerConfig(BaseModelWithExtra):
102129

103130
@staticmethod
104131
@typecheck
105-
def from_yaml_file(path: str | Path):
132+
def from_yaml_file(
133+
path: str | Path,
134+
dotpath: str | List[str] = []
135+
):
106136
config_dict = yaml_config_path_to_dict(path)
137+
config_dict = safe_deep_get(config_dict, dotpath)
138+
assert exists(config_dict), f'config not found at path {".".join(dotpath)}'
139+
107140
return TrainerConfig(**config_dict)
108141

109142
def create_instance(
@@ -137,10 +170,11 @@ def create_instance(
137170

138171
def create_instance_from_yaml_file(
139172
path: str | Path,
173+
dotpath: str | List[str] = [],
140174
**kwargs
141175
) -> Trainer:
142176

143-
trainer_config = TrainerConfig.from_yaml_file(path)
177+
trainer_config = TrainerConfig.from_yaml_file(path, dotpath)
144178
return trainer_config.create_instance(**kwargs)
145179

146180
# convenience functions

tests/test_af3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,10 @@ def test_alphafold3_without_msa_and_templates():
577577
def test_alphafold3_config():
578578
curr_dir = Path(__file__).parents[0]
579579
af3_yaml = curr_dir / 'alphafold3.yaml'
580+
trainer_yml = curr_dir / 'trainer.yaml'
580581

581582
alphafold3 = create_alphafold3_from_yaml(af3_yaml)
582583
assert isinstance(alphafold3, Alphafold3)
584+
585+
alphafold3_from_trainer_yml = create_alphafold3_from_yaml(trainer_yml, 'model')
586+
assert isinstance(alphafold3_from_trainer_yml, Alphafold3)

0 commit comments

Comments
 (0)