Skip to content

Commit 03c70d8

Browse files
committed
one more tiny change needed for trainer orchestrator
1 parent b91b088 commit 03c70d8

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

alphafold3_pytorch/configs.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ def safe_deep_get(
2929
dotpath: str | List[str], # dotpath notation, so accessing {'a': {'b'': {'c': 1}}} would be "a.b.c"
3030
default = None
3131
):
32-
if isinstance(dotpath ,str):
32+
if isinstance(dotpath, str):
3333
dotpath = dotpath.split('.')
3434

3535
for key in dotpath:
36-
if key not in d:
36+
if (
37+
not isinstance(d, dict) or \
38+
key not in d
39+
):
3740
return default
3841

3942
d = d[key]
@@ -48,7 +51,7 @@ def yaml_config_path_to_dict(
4851
if isinstance(path, str):
4952
path = Path(path)
5053

51-
assert path.is_file()
54+
assert path.is_file(), f'cannot find {str(path)}'
5255

5356
with open(str(path), 'r') as f:
5457
maybe_config_dict = yaml.safe_load(f)
@@ -113,7 +116,7 @@ def create_instance_from_yaml_file(
113116
return af3_config.create_instance()
114117

115118
class TrainerConfig(BaseModelWithExtra):
116-
model: Alphafold3Config
119+
model: Alphafold3Config | None = None
117120
num_train_steps: int
118121
batch_size: int
119122
grad_accum_every: int
@@ -142,6 +145,7 @@ def from_yaml_file(
142145
def create_instance(
143146
self,
144147
dataset: Dataset,
148+
model: Alphafold3 | None = None,
145149
fabric: Fabric | None = None,
146150
test_dataset: Dataset | None = None,
147151
optimizer: Optimizer | None = None,
@@ -152,7 +156,12 @@ def create_instance(
152156

153157
trainer_kwargs = self.model_dump()
154158

155-
alphafold3 = self.model.create_instance()
159+
assert exists(self.model) ^ exists(model), 'either model is available on the trainer config, or passed in when creating the instance, but not both or neither'
160+
161+
if exists(self.model):
162+
alphafold3 = self.model.create_instance()
163+
else:
164+
alphafold3 = model
156165

157166
trainer_kwargs.update(dict(
158167
model = alphafold3,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.28"
3+
version = "0.1.29"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
DataLoader,
1515
Trainer,
1616
TrainerConfig,
17-
create_trainer_from_yaml
17+
create_trainer_from_yaml,
18+
create_alphafold3_from_yaml
1819
)
1920

2021
# mock dataset
@@ -188,3 +189,21 @@ def test_trainer_config():
188189
# take a single training step
189190

190191
trainer()
192+
193+
# test creating trainer without model, given when creating instance
194+
195+
def test_trainer_config_without_model():
196+
curr_dir = Path(__file__).parents[0]
197+
198+
af3_yaml_path = curr_dir / 'alphafold3.yaml'
199+
trainer_yaml_path = curr_dir / 'trainer_without_model.yaml'
200+
201+
alphafold3 = create_alphafold3_from_yaml(af3_yaml_path)
202+
203+
trainer = create_trainer_from_yaml(
204+
trainer_yaml_path,
205+
model = alphafold3,
206+
dataset = MockAtomDataset(16)
207+
)
208+
209+
assert isinstance(trainer, Trainer)

0 commit comments

Comments
 (0)