
Commit 422efba

make one training step work with fabric + alphafold3 on cpu
1 parent 4532883 commit 422efba

6 files changed: +159 −15 lines changed

README.md
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-<img src="./alphafold3.png" width="450px"></img>
+<img src="./alphafold3.png" width="500px"></img>
 
 ## Alphafold 3 - Pytorch (wip)
 

alphafold3_pytorch/__init__.py
Lines changed: 3 additions & 1 deletion

@@ -32,7 +32,8 @@
 )
 
 from alphafold3_pytorch.trainer import (
-    Trainer
+    Trainer,
+    Alphafold3Input
 )
 
 __all__ = [
@@ -63,5 +64,6 @@
     ConfidenceHead,
     DistogramHead,
     Alphafold3,
+    Alphafold3Input,
     Trainer
]
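With `Alphafold3Input` added to `__all__`, it can now be imported from the package root alongside the existing names; a minimal sketch of the import this change enables (the same import is exercised by the updated test below):

from alphafold3_pytorch import Alphafold3, Alphafold3Input, Trainer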

alphafold3_pytorch/alphafold3.py
Lines changed: 3 additions & 2 deletions

@@ -2853,7 +2853,8 @@ def forward(
         pde_labels: Int['b n n'] | None = None,
         plddt_labels: Int['b n'] | None = None,
         resolved_labels: Int['b n'] | None = None,
-        return_loss_breakdown = False
+        return_loss_breakdown = False,
+        return_loss_if_possible: bool = True
     ) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
         atom_seq_len = atom_inputs.shape[-2]
@@ -3016,7 +3017,7 @@ def forward(
 
         # if neither atom positions or any labels are passed in, sample a structure and return
 
-        if not return_loss:
+        if not return_loss_if_possible or not return_loss:
             return self.edm.sample(
                 num_sample_steps = num_sample_steps,
                 atom_feats = atom_feats,
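The effect of the new flag, as a hedged sketch: when `atom_pos` or any labels are supplied, `forward` would normally compute and return a loss; passing `return_loss_if_possible = False` instead falls through to `self.edm.sample(...)` and returns sampled coordinates. The call below is illustrative only (`train_batch` is a hypothetical dict of inputs that includes ground-truth `atom_pos`):

# hypothetical: ground truth is present in train_batch, but a sampled structure is requested anyway
sampled_atom_pos = alphafold3(
    **train_batch,                     # includes atom_pos and/or labels
    return_loss_if_possible = False    # skip the loss path, sample with the diffusion module instead
)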

alphafold3_pytorch/trainer.py
Lines changed: 80 additions & 8 deletions

@@ -1,16 +1,41 @@
 from __future__ import annotations
 
 from alphafold3_pytorch.alphafold3 import Alphafold3
-from alphafold3_pytorch.typing import typecheck
+
+from typing import TypedDict
+from alphafold3_pytorch.typing import (
+    typecheck,
+    Int, Bool, Float
+)
 
 import torch
 from torch.optim import Adam, Optimizer
+from torch.utils.data import Dataset, DataLoader
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
 from ema_pytorch import EMA
 
 from lightning import Fabric
 
+# constants
+
+@typecheck
+class Alphafold3Input(TypedDict):
+    atom_inputs: Float['m dai']
+    residue_atom_lens: Int['n 2']
+    atompair_feats: Float['m m dap']
+    additional_residue_feats: Float['n 10']
+    templates: Float['t n n dt']
+    template_mask: Bool['t'] | None
+    msa: Float['s n dm']
+    msa_mask: Bool['s'] | None
+    atom_pos: Float['m 3'] | None
+    residue_atom_indices: Int['n'] | None
+    distance_labels: Int['n n'] | None
+    pae_labels: Int['n n'] | None
+    pde_labels: Int['n'] | None
+    resolved_labels: Int['n'] | None
+
 # helpers
 
 def exists(val):
@@ -19,6 +44,11 @@ def exists(val):
 def default(v, d):
     return v if exists(v) else d
 
+def cycle(dataloader: DataLoader):
+    while True:
+        for batch in dataloader:
+            yield batch
+
 def default_lambda_lr_fn(steps):
     # 1000 step warmup
 
@@ -40,6 +70,10 @@ def __init__(
         self,
         model: Alphafold3,
         *,
+        dataset: Dataset,
+        num_train_steps: int,
+        batch_size: int,
+        grad_accum_every: int = 1,
         optimizer: Optimizer | None = None,
         scheduler: LRScheduler | None = None,
         ema_decay = 0.999,
@@ -69,12 +103,13 @@ def __init__(
 
         # exponential moving average
 
-        self.ema_model = EMA(
-            model,
-            beta = ema_decay,
-            include_online_model = False,
-            **ema_kwargs
-        )
+        if self.is_main:
+            self.ema_model = EMA(
+                model,
+                beta = ema_decay,
+                include_online_model = False,
+                **ema_kwargs
+            )
 
         # optimizer
 
@@ -87,10 +122,19 @@ def __init__(
 
         self.optimizer = optimizer
 
+        # data
+
+        self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
+
+        self.num_train_steps = num_train_steps
+        self.grad_accum_every = grad_accum_every
+
         # setup fabric
 
         self.model, self.optimizer = fabric.setup(self.model, self.optimizer)
 
+        fabric.setup_dataloaders(self.dataloader)
+
         # scheduler
 
         if not exists(scheduler):
@@ -102,7 +146,35 @@
 
         self.clip_grad_norm = clip_grad_norm
 
+    @property
+    def is_main(self):
+        return self.fabric.global_rank == 0
+
     def __call__(
         self
     ):
-        pass
+        dl = iter(self.dataloader)
+
+        steps = 0
+
+        while steps < self.num_train_steps:
+            for _ in range(self.grad_accum_every):
+                inputs = next(dl)
+
+                loss = self.model(**inputs)
+
+                self.fabric.backward(loss / self.grad_accum_every)
+
+            print(f'loss: {loss.item():.3f}')
+
+            self.optimizer.step()
+
+            if self.is_main:
+                self.ema_model.update()
+
+            self.scheduler.step()
+            self.optimizer.zero_grad()
+
+            steps += 1
+
+        print(f'training complete')
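A note on the new loop: each optimizer step consumes `grad_accum_every` batches, and the loss is divided by `grad_accum_every` before `fabric.backward`, so accumulated gradients average rather than sum over the window (effective batch size per device is `batch_size * grad_accum_every`). A minimal standalone sketch of the same pattern, assuming `fabric`, `model`, `optimizer`, `grad_accum_every`, and an iterator `dl` over the dataloader are already set up as in `__init__` above:

# one optimizer step with gradient accumulation, mirroring Trainer.__call__
for _ in range(grad_accum_every):
    inputs = next(dl)                         # one batch of Alphafold3Input tensors
    loss = model(**inputs)                    # labels present, so Alphafold3 returns a scalar loss
    fabric.backward(loss / grad_accum_every)  # scale so grads average over the accumulation window

optimizer.step()
optimizer.zero_grad()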

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.38"
+version = "0.0.39"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py
Lines changed: 71 additions & 2 deletions

@@ -1,14 +1,74 @@
 import os
 os.environ['TYPECHECK'] = 'True'
 
-import torch
 import pytest
+import torch
+from torch.utils.data import Dataset
 
 from alphafold3_pytorch import (
     Alphafold3,
+    Alphafold3Input,
     Trainer
 )
 
+# mock dataset
+
+class AtomDataset(Dataset):
+    def __init__(
+        self,
+        seq_len = 16,
+        atoms_per_window = 27
+    ):
+        self.seq_len = seq_len
+        self.atom_seq_len = seq_len * atoms_per_window
+
+    def __len__(self):
+        return 100
+
+    def __getitem__(self, idx):
+        seq_len = self.seq_len
+        atom_seq_len = self.atom_seq_len
+
+        atom_inputs = torch.randn(atom_seq_len, 77)
+        residue_atom_lens = torch.randint(0, 27, (seq_len,))
+        atompair_feats = torch.randn(atom_seq_len, atom_seq_len, 16)
+        additional_residue_feats = torch.randn(seq_len, 10)
+
+        templates = torch.randn(2, seq_len, seq_len, 44)
+        template_mask = torch.ones((2,)).bool()
+
+        msa = torch.randn(7, seq_len, 64)
+        msa_mask = torch.ones((7,)).bool()
+
+        # required for training, but omitted on inference
+
+        atom_pos = torch.randn(atom_seq_len, 3)
+        residue_atom_indices = torch.randint(0, 27, (seq_len,))
+
+        distance_labels = torch.randint(0, 37, (seq_len, seq_len))
+        pae_labels = torch.randint(0, 64, (seq_len, seq_len))
+        pde_labels = torch.randint(0, 64, (seq_len, seq_len))
+        plddt_labels = torch.randint(0, 50, (seq_len,))
+        resolved_labels = torch.randint(0, 2, (seq_len,))
+
+        return Alphafold3Input(
+            atom_inputs = atom_inputs,
+            residue_atom_lens = residue_atom_lens,
+            atompair_feats = atompair_feats,
+            additional_residue_feats = additional_residue_feats,
+            templates = templates,
+            template_mask = template_mask,
+            msa = msa,
+            msa_mask = msa_mask,
+            atom_pos = atom_pos,
+            residue_atom_indices = residue_atom_indices,
+            distance_labels = distance_labels,
+            pae_labels = pae_labels,
+            pde_labels = pde_labels,
+            plddt_labels = plddt_labels,
+            resolved_labels = resolved_labels
+        )
+
 def test_trainer():
     alphafold3 = Alphafold3(
         dim_atom_inputs = 77,
@@ -33,6 +93,15 @@ def test_trainer():
         ),
     )
 
-    trainer = Trainer(alphafold3)
+    dataset = AtomDataset()
+
+    trainer = Trainer(
+        alphafold3,
+        dataset = dataset,
+        accelerator = 'cpu',
+        num_train_steps = 2,
+        batch_size = 1,
+        grad_accum_every = 2
+    )
 
     trainer()
