
Commit 19a934a
1 parent 2aa9ed3

bridge PDBInput to the Trainer through PDBDataset, and demonstrate that end-to-end training steps can be run

File tree

4 files changed: +175 -4 lines

alphafold3_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@
     MoleculeInput,
     Alphafold3Input,
     PDBInput,
+    PDBDataset,
     maybe_transform_to_atom_input,
     maybe_transform_to_atom_inputs
 )
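With the re-export in place, PDBDataset is importable from the package root alongside the existing input types; a trivial usage note, not part of the diff:

    from alphafold3_pytorch import PDBInput, PDBDataset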

alphafold3_pytorch/inputs.py

Lines changed: 42 additions & 2 deletions

@@ -3,24 +3,29 @@
 import copy
 import json
 import os
+from pathlib import Path
 from collections import defaultdict
 from dataclasses import asdict, dataclass, field
 from functools import partial
 from itertools import groupby
 from typing import Any, Callable, List, Set, Tuple, Type
 
 import einx
+
 import numpy as np
 import torch
+from torch import tensor
+from torch.utils.data import Dataset
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
 from loguru import logger
 from pdbeccdutils.core import ccd_reader
+
 from rdkit import Chem
 from rdkit.Chem import AllChem, rdDetermineBonds
 from rdkit.Chem.rdchem import Atom, Mol
 from rdkit.Geometry import Point3D
-from torch import tensor
-from torch.nn.utils.rnn import pad_sequence
 
 from alphafold3_pytorch.common import amino_acid_constants, dna_constants, rna_constants
 from alphafold3_pytorch.common.biomolecule import (

@@ -1864,6 +1869,41 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
     return molecule_input
 
 
+# PDB Dataset
+
+class PDBDataset(Dataset):
+    def __init__(
+        self,
+        folder: str | Path,
+        training: bool | None = None,  # extra training flag placed by Alex on PDBInput
+        **pdb_input_kwargs
+    ):
+        if isinstance(folder, str):
+            folder = Path(folder)
+
+        assert folder.exists() and folder.is_dir()
+
+        self.files = [*folder.glob('*.cif')]
+        self.pdb_input_kwargs = pdb_input_kwargs
+        self.training = training
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, idx):
+
+        kwargs = self.pdb_input_kwargs
+
+        if exists(self.training):
+            kwargs = {**kwargs, 'training': self.training}
+
+        pdb_input = PDBInput(
+            str(self.files[idx]),
+            **kwargs
+        )
+
+        return pdb_input
+
 # the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
 # this can be preprocessed or will be taken care of automatically within the Trainer during data collation
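A minimal usage sketch of the new class (not part of the diff; the folder path here is hypothetical). PDBDataset lazily wraps every .cif file in a folder as a PDBInput; these can be handed to the Trainer, which collates them down to AtomInput automatically, or preprocessed by hand:

    from alphafold3_pytorch import PDBDataset, maybe_transform_to_atom_inputs

    dataset = PDBDataset('./data/pdb', training = True)  # hypothetical folder of .cif files

    pdb_input = dataset[0]                                     # a PDBInput for one mmCIF file
    atom_inputs = maybe_transform_to_atom_inputs([pdb_input])  # optional manual preprocessing down to AtomInput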

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.15"
+version = "0.2.16"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 131 additions & 1 deletion

@@ -12,6 +12,7 @@
 
 from alphafold3_pytorch import (
     Alphafold3,
+    PDBDataset,
     AtomInput,
     DataLoader,
     Trainer,

@@ -106,7 +107,7 @@ def remove_test_folders():
     yield
     shutil.rmtree('./test-folder')
 
-def test_trainer(remove_test_folders):
+def test_trainer_with_mock_atom_input(remove_test_folders):
 
     alphafold3 = Alphafold3(
         dim_atom_inputs = 77,

@@ -204,6 +205,135 @@ def test_trainer(remove_test_folders):
 
     alphafold3 = Alphafold3.init_and_load('./test-folder/nested/folder2/training.pt')
 
+# testing trainer with pdb inputs
+
+@pytest.fixture()
+def populate_mock_pdb_and_remove_test_folders():
+    proj_root = Path('.')
+    working_cif_file = proj_root / 'data' / 'test' / '7a4d-assembly1.cif'
+
+    pytest_root_folder = Path('./test-folder')
+    data_folder = pytest_root_folder / 'data'
+
+    train_folder = data_folder / 'train'
+    valid_folder = data_folder / 'valid'
+    test_folder = data_folder / 'test'
+
+    train_folder.mkdir(exist_ok = True, parents = True)
+    valid_folder.mkdir(exist_ok = True, parents = True)
+    test_folder.mkdir(exist_ok = True, parents = True)
+
+    for i in range(100):
+        shutil.copy2(str(working_cif_file), str(train_folder / f'{i}.cif'))
+
+    for i in range(4):
+        shutil.copy2(str(working_cif_file), str(valid_folder / f'{i}.cif'))
+
+    for i in range(2):
+        shutil.copy2(str(working_cif_file), str(test_folder / f'{i}.cif'))
+
+    yield
+
+    shutil.rmtree('./test-folder')
+
+def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
+
+    alphafold3 = Alphafold3(
+        dim_atom=8,
+        dim_atompair=8,
+        dim_input_embedder_token=8,
+        dim_single=8,
+        dim_pairwise=8,
+        dim_token=8,
+        dim_atom_inputs=3,
+        dim_atompair_inputs=1,
+        atoms_per_window=27,
+        dim_template_feats=44,
+        num_dist_bins=38,
+        confidence_head_kwargs=dict(pairformer_depth=1),
+        template_embedder_kwargs=dict(pairformer_stack_depth=1),
+        msa_module_kwargs=dict(depth=1),
+        pairformer_stack=dict(depth=1),
+        diffusion_module_kwargs=dict(
+            atom_encoder_depth=1,
+            token_transformer_depth=1,
+            atom_decoder_depth=1,
+        ),
+    )
+
+    dataset = PDBDataset('./test-folder/data/train')
+    valid_dataset = PDBDataset('./test-folder/data/valid')
+    test_dataset = PDBDataset('./test-folder/data/test')
+
+    # test saving and loading from Alphafold3, independent of lightning
+
+    dataloader = DataLoader(dataset, batch_size = 2)
+    inputs = next(iter(dataloader))
+
+    alphafold3.eval()
+    _, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
+    before_distogram = breakdown.distogram
+
+    path = './test-folder/nested/folder/af3'
+    alphafold3.save(path, overwrite = True)
+
+    # load from scratch, along with saved hyperparameters
+
+    alphafold3 = Alphafold3.init_and_load(path)
+
+    alphafold3.eval()
+    _, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
+    after_distogram = breakdown.distogram
+
+    assert torch.allclose(before_distogram, after_distogram)
+
+    # test training + validation
+
+    trainer = Trainer(
+        alphafold3,
+        dataset = dataset,
+        valid_dataset = valid_dataset,
+        test_dataset = test_dataset,
+        accelerator = 'cpu',
+        num_train_steps = 2,
+        batch_size = 1,
+        valid_every = 1,
+        grad_accum_every = 2,
+        checkpoint_every = 1,
+        checkpoint_folder = './test-folder/checkpoints',
+        overwrite_checkpoints = True,
+        ema_kwargs = dict(
+            use_foreach = True,
+            update_after_step = 0,
+            update_every = 1
+        )
+    )
+
+    trainer()
+
+    # assert checkpoints created
+
+    assert Path(f'./test-folder/checkpoints/({trainer.train_id})_af3.ckpt.1.pt').exists()
+
+    # assert can load latest checkpoint by loading from a directory
+
+    trainer.load('./test-folder/checkpoints', strict = False)
+
+    assert exists(trainer.model_loaded_from_path)
+
+    # saving and loading from trainer
+
+    trainer.save('./test-folder/nested/folder2/training.pt', overwrite = True)
+    trainer.load('./test-folder/nested/folder2/training.pt', strict = False)
+
+    # allow for only loading model, needed for fine-tuning logic
+
+    trainer.load('./test-folder/nested/folder2/training.pt', only_model = True, strict = False)
+
+    # also allow for loading Alphafold3 directly from training ckpt
+
+    alphafold3 = Alphafold3.init_and_load('./test-folder/nested/folder2/training.pt')
+
 # test use of collation fn outside of trainer
 
 def test_collate_fn():
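Distilled from the new test (a sketch, not part of the diff: the folder path is assumed, and alphafold3 is constructed as in test_trainer_with_pdb_input above), the end-to-end bridge comes down to a few lines:

    from dataclasses import asdict
    from alphafold3_pytorch import PDBDataset, DataLoader

    dataset = PDBDataset('./data/pdb')                # assumed folder of *.cif files
    dataloader = DataLoader(dataset, batch_size = 2)  # the library DataLoader collates PDBInput into a batched AtomInput

    batch = next(iter(dataloader))
    loss = alphafold3(**asdict(batch))                # forward training pass, as exercised in the test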