Skip to content

Commit 2515ad2

Browse files
committed
add ability to instantiate trainer directly from yaml file, only needing to pass in datasets
1 parent 246f4b0 commit 2515ad2

File tree

10 files changed

+145
-39
lines changed

10 files changed

+145
-39
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
)
4040

4141
from alphafold3_pytorch.configs import (
42-
Alphafold3Config
42+
Alphafold3Config,
43+
TrainerConfig
4344
)
4445

4546
__all__ = [
@@ -72,5 +73,6 @@
7273
Alphafold3,
7374
Alphafold3Config,
7475
AtomInput,
75-
Trainer
76+
Trainer,
77+
TrainerConfig
7678
]

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Sequential,
1818
)
1919

20-
from typing import Literal, Tuple, NamedTuple, Callable
20+
from typing import List, Literal, Tuple, NamedTuple, Callable
2121

2222
from alphafold3_pytorch.typing import (
2323
Float,
@@ -2830,7 +2830,7 @@ def __init__(
28302830
dim_single = 384,
28312831
dim_pairwise = 128,
28322832
dim_token = 768,
2833-
distance_bins: List[float] = torch.linspace(3, 20, 38).tolist(),
2833+
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
28342834
ignore_index = -1,
28352835
num_dist_bins: int | None = None,
28362836
num_plddt_bins = 50,
@@ -3020,12 +3020,12 @@ def __init__(
30203020

30213021
# logit heads
30223022

3023-
distance_bins = Tensor(distance_bins)
3023+
distance_bins_tensor = Tensor(distance_bins)
30243024

3025-
self.register_buffer('distance_bins', distance_bins)
3026-
num_dist_bins = default(num_dist_bins, len(distance_bins))
3025+
self.register_buffer('distance_bins', distance_bins_tensor)
3026+
num_dist_bins = default(num_dist_bins, len(distance_bins_tensor))
30273027

3028-
assert len(distance_bins) == num_dist_bins, '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
3028+
assert len(distance_bins_tensor) == num_dist_bins, '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
30293029

30303030
self.distogram_head = DistogramHead(
30313031
dim_pairwise = dim_pairwise,

alphafold3_pytorch/configs.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
from __future__ import annotations
22

33
from alphafold3_pytorch.typing import typecheck
4+
from typing import Callable, List
5+
46
from alphafold3_pytorch.alphafold3 import Alphafold3
57

8+
from alphafold3_pytorch.trainer import (
9+
Trainer,
10+
Dataset,
11+
Fabric,
12+
Optimizer,
13+
LRScheduler
14+
)
15+
616
import yaml
717
from pathlib import Path
818

@@ -27,7 +37,7 @@ def yaml_config_path_to_dict(
2737
maybe_config_dict = yaml.safe_load(f)
2838

2939
assert exists(maybe_config_dict), f'unable to parse yaml config at {str(path)}'
30-
assert isinstance(maybe_config_dict, dict), f'yaml config file is not a dictionary'
40+
assert isinstance(maybe_config_dict, dict), 'yaml config file is not a dictionary'
3141

3242
return maybe_config_dict
3343

@@ -76,4 +86,59 @@ def create_instance_from_yaml_file(path: str | Path) -> Alphafold3:
7686
return af3_config.create_instance()
7787

7888
class TrainerConfig(BaseModelWithExtra):
79-
pass
89+
model: Alphafold3Config
90+
num_train_steps: int
91+
batch_size: int
92+
grad_accum_every: int
93+
valid_every: int
94+
ema_decay: float
95+
lr: float
96+
clip_grad_norm: int | float
97+
accelerator: str
98+
checkpoint_prefix: str
99+
checkpoint_every: int
100+
checkpoint_folder: str
101+
overwrite_checkpoints: bool
102+
103+
@staticmethod
104+
@typecheck
105+
def from_yaml_file(path: str | Path):
106+
config_dict = yaml_config_path_to_dict(path)
107+
return TrainerConfig(**config_dict)
108+
109+
def create_instance(
110+
self,
111+
dataset: Dataset,
112+
fabric: Fabric | None = None,
113+
test_dataset: Dataset | None = None,
114+
optimizer: Optimizer | None = None,
115+
scheduler: LRScheduler | None = None,
116+
valid_dataset: Dataset | None = None,
117+
map_dataset_input_fn: Callable | None = None,
118+
) -> Trainer:
119+
120+
trainer_kwargs = self.model_dump()
121+
122+
alphafold3 = self.model.create_instance()
123+
124+
trainer_kwargs.update(dict(
125+
model = alphafold3,
126+
dataset = dataset,
127+
fabric = fabric,
128+
test_dataset = test_dataset,
129+
optimizer = optimizer,
130+
scheduler = scheduler,
131+
valid_dataset = valid_dataset,
132+
map_dataset_input_fn = map_dataset_input_fn
133+
))
134+
135+
trainer = Trainer(**trainer_kwargs)
136+
return trainer
137+
138+
def create_instance_from_yaml_file(
139+
path: str | Path,
140+
**kwargs
141+
) -> Trainer:
142+
143+
trainer_config = TrainerConfig.from_yaml_file(path)
144+
return trainer_config.create_instance(**kwargs)

alphafold3_pytorch/pdb_dataset_curation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188

189189
# Helper functions
190190

191-
def exists(v: Any) -> bool:
191+
def exists(v) -> bool:
192192
"""Return `False` if `v` is `None`, else return `True`."""
193193
return v is not None
194194

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.23"
3+
version = "0.1.24"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/alphafold3.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ dim_input_embedder_token: 384
1010
dim_single: 384
1111
dim_pairwise: 128
1212
dim_token: 768
13-
distance_bins: 38
1413
ignore_index: -1
1514
num_dist_bins: null
1615
num_plddt_bins: 50

tests/test_af3.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import pytest
6+
from pathlib import Path
67

78
from alphafold3_pytorch import (
89
SmoothLDDTLoss,
@@ -22,6 +23,7 @@
2223
ConfidenceHead,
2324
DistogramHead,
2425
Alphafold3,
26+
Alphafold3Config
2527
)
2628

2729
from alphafold3_pytorch.alphafold3 import (
@@ -361,7 +363,7 @@ def test_confidence_head():
361363

362364
confidence_head = ConfidenceHead(
363365
dim_single_inputs = 77,
364-
atompair_dist_bins = torch.linspace(3, 20, 37),
366+
atompair_dist_bins = torch.linspace(3, 20, 37).tolist(),
365367
dim_single = 384,
366368
dim_pairwise = 128,
367369
)
@@ -565,3 +567,12 @@ def test_alphafold3_without_msa_and_templates():
565567
)
566568

567569
loss.backward()
570+
571+
# test creation from config
572+
573+
def test_alphafold3_config():
574+
curr_dir = Path(__file__).parents[0]
575+
af3_yaml = curr_dir / 'alphafold3.yaml'
576+
577+
alphafold3 = Alphafold3Config.create_instance_from_yaml_file(af3_yaml)
578+
assert isinstance(alphafold3, Alphafold3)

tests/test_config.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/test_trainer.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
Alphafold3,
1313
AtomInput,
1414
DataLoader,
15-
Trainer
15+
Trainer,
16+
TrainerConfig
1617
)
1718

1819
# mock dataset
@@ -165,3 +166,20 @@ def test_trainer():
165166
# also allow for loading Alphafold3 directly from training ckpt
166167

167168
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training.pt')
169+
170+
# test creating trainer + alphafold3 from config
171+
172+
def test_trainer_config():
173+
curr_dir = Path(__file__).parents[0]
174+
trainer_yaml_path = curr_dir / 'trainer.yaml'
175+
176+
trainer = TrainerConfig.create_instance_from_yaml_file(
177+
trainer_yaml_path,
178+
dataset = MockAtomDataset(16)
179+
)
180+
181+
assert isinstance(trainer, Trainer)
182+
183+
# take a single training step
184+
185+
trainer()

tests/trainer.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
---
2+
model:
3+
dim_atom_inputs: 77
4+
dim_template_feats: 44
5+
dim_template_model: 64
6+
atoms_per_window: 27
7+
dim_atom: 128
8+
dim_atompair_inputs: 5
9+
dim_atompair: 16
10+
dim_input_embedder_token: 384
11+
dim_single: 384
12+
dim_pairwise: 128
13+
dim_token: 768
14+
ignore_index: -1
15+
num_dist_bins: null
16+
num_plddt_bins: 50
17+
num_pde_bins: 64
18+
num_pae_bins: 64
19+
sigma_data: 16
20+
diffusion_num_augmentations: 4
21+
loss_confidence_weight: 0.0001
22+
loss_distogram_weight: 0.01
23+
loss_diffusion_weight: 4.
24+
num_train_steps: 1
25+
batch_size: 1
26+
grad_accum_every: 1
27+
valid_every: 1
28+
ema_decay: 0.999
29+
lr: 0.0001
30+
clip_grad_norm: 10.
31+
accelerator: cpu
32+
checkpoint_prefix: af3.ckpt.
33+
checkpoint_every: 1000
34+
checkpoint_folder: ./checkpoints
35+
overwrite_checkpoints: false

0 commit comments

Comments (0)