
Commit 95c99d1
1 parent 68a3ac7

able to precompute atom dataset from a pdb dataset all the way from a training config and take a training step

8 files changed, +152 -33 lines

.github/workflows/test.yml (3 additions, 0 deletions)

```diff
@@ -1,6 +1,9 @@
 name: Pytest
 on: [push, pull_request]
 
+env:
+  TYPECHECK: True
+
 jobs:
   build:
 
```
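The new `env` block exports `TYPECHECK: True` to every CI job, enabling runtime type checking during the test run. A minimal sketch of how such a flag is commonly consumed on the Python side; `should_typecheck` and the beartype wiring here are illustrative assumptions, not the repo's exact `tensor_typing` code:

```python
# sketch: env-gated runtime type checking, as the TYPECHECK flag in CI
# suggests; names and wiring are assumptions, not the repo's exact code
import os

should_typecheck = os.environ.get('TYPECHECK', '').lower() in ('1', 'true')

def typecheck(fn):
    if not should_typecheck:
        return fn  # no-op outside CI, so checks add zero overhead

    from beartype import beartype  # imported lazily, only when enabled
    return beartype(fn)
```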

alphafold3_pytorch/__init__.py (0 additions, 1 deletion)

```diff
@@ -45,7 +45,6 @@
     AtomDataset,
     PDBInput,
     PDBDataset,
-    DatasetWithReturnedIndex,
     maybe_transform_to_atom_input,
     maybe_transform_to_atom_inputs,
 )
```

alphafold3_pytorch/confidence.py (0 additions, 1 deletion)

```diff
@@ -11,7 +11,6 @@
     repeat_consecutive_with_lens
 )
 
-
 from alphafold3_pytorch.tensor_typing import (
     Float,
     Int,
```

alphafold3_pytorch/configs.py (13 additions, 1 deletion)

```diff
@@ -7,7 +7,8 @@
 
 from alphafold3_pytorch.inputs import (
     AtomDataset,
-    PDBDataset
+    PDBDataset,
+    pdb_dataset_to_atom_inputs
 )
 
 from alphafold3_pytorch.trainer import (
@@ -145,6 +146,7 @@ class DatasetConfig(BaseModelWithExtra):
     train_folder: DirectoryPath
     valid_folder: DirectoryPath | None = None
     test_folder: DirectoryPath | None = None
+    convert_pdb_to_atom: bool = False
     train_weighted_sampler: WeightedPDBSamplerConfig | None = None
     kwargs: dict = dict()
 
@@ -219,6 +221,11 @@ def create_instance(
         dataset_type = dataset_config.dataset_type
         dataset_kwargs = dataset_config.kwargs
 
+        convert_pdb_to_atom = dataset_config.convert_pdb_to_atom
+
+        if convert_pdb_to_atom:
+            assert dataset_type == 'atom', 'must be `atom` dataset_type if `convert_pdb_to_atom` is set to True'
+
         if dataset_type == 'pdb':
             dataset_klass = PDBDataset
         elif dataset_type == 'atom':
@@ -230,6 +237,11 @@ def create_instance(
 
         if exists(train_folder):
             assert 'dataset' not in trainer_kwargs
+
+            if convert_pdb_to_atom:
+                pdb_dataset = PDBDataset(train_folder, **dataset_kwargs)
+                train_folder = pdb_dataset_to_atom_inputs(pdb_dataset)
+
             dataset = dataset_klass(train_folder, **dataset_kwargs)
             trainer_kwargs.update(dataset = dataset)
 
```
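Functionally, the new `convert_pdb_to_atom` branch precomputes atom inputs before the dataset the trainer sees is constructed. A minimal sketch of the equivalent manual flow, with an illustrative folder path:

```python
# sketch of what `convert_pdb_to_atom: true` does before training;
# the ./data/train path is illustrative
from alphafold3_pytorch.inputs import (
    AtomDataset,
    PDBDataset,
    pdb_dataset_to_atom_inputs
)

pdb_dataset = PDBDataset('./data/train')

# serializes each entry as an AtomInput .pt file, returning the new folder
atom_folder = pdb_dataset_to_atom_inputs(pdb_dataset)

# this folder is what `dataset_klass(train_folder, **dataset_kwargs)` receives
dataset = AtomDataset(atom_folder)
```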

alphafold3_pytorch/inputs.py (59 additions, 17 deletions)

```diff
@@ -7,6 +7,7 @@
 from functools import partial
 from itertools import groupby
 from collections import defaultdict
+from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
 from typing import Any, Callable, List, Literal, Set, Tuple, Type
 
@@ -187,11 +188,12 @@ def dict(self):
 @typecheck
 def atom_input_to_file(
     atom_input: AtomInput,
-    path: str,
+    path: str | Path,
     overwrite: bool = False
 ) -> Path:
 
-    path = Path(path)
+    if isinstance(path, str):
+        path = Path(path)
 
     if not overwrite:
         assert not path.exists()
@@ -211,6 +213,53 @@ def file_to_atom_input(path: str | Path) -> AtomInput:
     atom_input_dict = torch.load(str(path))
     return AtomInput(**atom_input_dict)
 
+@typecheck
+def pdb_dataset_to_atom_inputs(
+    pdb_dataset: PDBDataset,
+    *,
+    output_atom_folder: str | Path | None = None,
+    indices: Iterable | None = None,
+    return_atom_dataset = False,
+    verbose = True
+) -> Path | AtomDataset:
+
+    if not exists(output_atom_folder):
+        pdb_folder = Path(pdb_dataset.folder).resolve()
+        parent_folder = pdb_folder.parents[0]
+        output_atom_folder = parent_folder / f'{pdb_folder.stem}.atom-inputs'
+
+    if isinstance(output_atom_folder, str):
+        output_atom_folder = Path(output_atom_folder)
+
+    if not exists(indices):
+        indices = torch.randperm(len(pdb_dataset)).tolist()
+
+    indices = iter(indices)
+
+    to_atom_input_fn = compose(
+        pdb_input_to_molecule_input,
+        molecule_to_atom_input
+    )
+
+    while index := next(indices, None):
+        if not exists(index):
+            break
+
+        pdb_input = pdb_dataset[index]
+
+        atom_input = to_atom_input_fn(pdb_input)
+        atom_input_path = output_atom_folder / f'{index}.pt'
+
+        atom_input_to_file(atom_input, atom_input_path)
+
+        if verbose:
+            logger.info(f'converted pdb input with index {index} to {str(atom_input_path)}')
+
+    if not return_atom_dataset:
+        return output_atom_folder
+
+    return AtomDataset(output_atom_folder)
+
 # Atom dataset that returns a AtomInput based on folders of atom inputs stored on disk
 
 class AtomDataset(Dataset):
@@ -221,11 +270,13 @@ def __init__(
         if isinstance(folder, str):
             folder = Path(folder)
 
-        assert folder.exists() and folder.is_dir()
+        assert folder.exists() and folder.is_dir(), f'atom dataset not found at {str(folder)}'
 
         self.folder = folder
         self.files = [*folder.glob('**/*.pt')]
 
+        assert len(self) > 0, f'no valid atom .pt files found at {str(folder)}'
+
     def __len__(self):
         return len(self.files)
 
@@ -1919,19 +1970,6 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
 
 # datasets
 
-# dataset wrapper for returning index along with dataset item
-# for caching logic both integrated into trainer and for precaching
-
-class DatasetWithReturnedIndex(Dataset):
-    def __init__(self, dataset: Dataset):
-        self.dataset = dataset
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        return idx, self.dataset[idx]
-
 # PDB dataset that returns a PDBInput based on folder
 
 class PDBDataset(Dataset):
@@ -1953,7 +1991,9 @@ def __init__(
         if isinstance(folder, str):
             folder = Path(folder)
 
-        assert folder.exists() and folder.is_dir()
+        assert folder.exists() and folder.is_dir(), f'{str(folder)} does not exist for PDBDataset'
+        self.folder = folder
+
         self.files = {
             os.path.splitext(os.path.basename(file.name))[0]: file
             for file in folder.glob(os.path.join("**", "*.cif"))
@@ -1967,6 +2007,8 @@ def __init__(
         self.training = training
         self.pdb_input_kwargs = pdb_input_kwargs
 
+        assert len(self) > 0, f'no valid mmcifs / pdbs found at {str(folder)}'
+
     def __len__(self):
         """Return the number of PDB mmCIF files in the dataset."""
         return len(self.files)
```
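`pdb_dataset_to_atom_inputs` can also hand back the `AtomDataset` directly and convert only selected entries. A small usage sketch, with illustrative paths and indices:

```python
# sketch: convert a subset of a PDB dataset and get the AtomDataset back;
# the folder path and indices are illustrative
from alphafold3_pytorch.inputs import PDBDataset, pdb_dataset_to_atom_inputs

pdb_dataset = PDBDataset('./data/train')

atom_dataset = pdb_dataset_to_atom_inputs(
    pdb_dataset,
    output_atom_folder = './data/train.atom-inputs',
    indices = [3, 1, 2],           # which entries to convert, in this order
    return_atom_dataset = True     # return the AtomDataset, not just the Path
)

atom_input = atom_dataset[0]  # an AtomInput deserialized from disk
```

Note that the conversion loop's `while index := next(indices, None)` condition is falsy for index `0` as well as for the `None` sentinel, so an explicit index list that omits `0` (or places it last), as above, avoids cutting the conversion short.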

pyproject.toml (1 addition, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.42"
+version = "0.2.43"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
```
tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml (new file, 61 additions; path inferred from the new test in tests/test_trainer.py, indentation reconstructed)

```yaml
---
model:
  dim_atom: 4
  dim_atompair: 4
  dim_input_embedder_token: 4
  dim_single: 4
  dim_pairwise: 4
  dim_token: 4
  dim_atom_inputs: 3
  dim_atompair_inputs: 1
  dim_template_model: 8
  atoms_per_window: 27
  dim_template_feats: 44
  num_dist_bins: 38
  ignore_index: -1
  num_dist_bins: null
  num_plddt_bins: 50
  num_pde_bins: 64
  num_pae_bins: 64
  sigma_data: 16
  diffusion_num_augmentations: 4
  loss_confidence_weight: 0.0001
  loss_distogram_weight: 0.01
  loss_diffusion_weight: 4.
  confidence_head_kwargs:
    pairformer_depth: 1
  template_embedder_kwargs:
    pairformer_stack_depth: 1
  msa_module_kwargs:
    depth: 1
  pairformer_stack:
    depth: 1
    pair_bias_attn_dim_head: 4
    pair_bias_attn_heads: 2
  diffusion_module_kwargs:
    atom_encoder_depth: 1
    token_transformer_depth: 1
    atom_decoder_depth: 1
    atom_decoder_kwargs:
      attn_pair_bias_kwargs:
        dim_head: 4
    atom_encoder_kwargs:
      attn_pair_bias_kwargs:
        dim_head: 4
num_train_steps: 1
batch_size: 1
grad_accum_every: 1
valid_every: 1
use_ema: false
ema_decay: 0.999
lr: 0.0001
clip_grad_norm: 10.
accelerator: cpu
checkpoint_prefix: af3.ckpt.
checkpoint_every: 1
checkpoint_folder: ./checkpoints
overwrite_checkpoints: false
dataset_config:
  dataset_type: atom
  convert_pdb_to_atom: true
  train_folder: ./test-folder/data/train
```
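The `dataset_config` block at the end is what exercises the new code path: `dataset_type: atom` together with `convert_pdb_to_atom: true`. Since `train_folder` is a pydantic `DirectoryPath`, the folder must already exist when the config is validated. A quick sanity check of just that section, using the `DatasetConfig` model from the configs.py diff above:

```python
# sketch: validating the dataset section against the pydantic model shown
# in the configs.py diff; the folder must exist for DirectoryPath to pass
from alphafold3_pytorch.configs import DatasetConfig

dataset_config = DatasetConfig(
    dataset_type = 'atom',
    convert_pdb_to_atom = True,
    train_folder = './test-folder/data/train'
)

assert dataset_config.convert_pdb_to_atom
```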

tests/test_trainer.py (15 additions, 12 deletions)

```diff
@@ -12,7 +12,6 @@
 from alphafold3_pytorch import (
     Alphafold3,
     PDBDataset,
-    DatasetWithReturnedIndex,
     AtomInput,
     atom_input_to_file,
     DataLoader,
@@ -314,17 +313,6 @@ def test_collate_fn():
 
     _, breakdown = alphafold3(**asdict(batched_atom_inputs), return_loss_breakdown = True)
 
-# test use of a dataset wrapper that returns the indices, for caching
-
-def test_dataset_return_index_wrapper():
-    dataset = MockAtomDataset(5)
-    wrapped_dataset = DatasetWithReturnedIndex(dataset)
-
-    assert len(wrapped_dataset) == len(dataset)
-
-    idx, item = wrapped_dataset[3]
-    assert idx == 3 and isinstance(item, AtomInput)
-
 # test creating trainer + alphafold3 from config
 
 def test_trainer_config():
@@ -387,6 +375,21 @@ def test_trainer_config_with_atom_dataset():
 
     shutil.rmtree(atom_folder, ignore_errors = True)
 
+# test creating trainer + alphafold3 with atom dataset that is precomputed from a pdb dataset
+
+def test_trainer_config_with_atom_dataset_from_pdb_dataset(populate_mock_pdb_and_remove_test_folders):
+
+    curr_dir = Path(__file__).parents[0]
+    trainer_yaml_path = curr_dir / 'configs/trainer_with_atom_dataset_created_from_pdb.yaml'
+
+    trainer = create_trainer_from_yaml(trainer_yaml_path)
+
+    assert isinstance(trainer, Trainer)
+
+    # take a single training step
+
+    trainer()
+
 # test creating trainer without model, given when creating instance
 
 def test_trainer_config_without_model():
```
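Outside pytest, the end-to-end flow the new test drives looks roughly like this, assuming `create_trainer_from_yaml` is importable from the package root as in the test module and the config's `train_folder` has been populated as the fixtures do:

```python
# sketch of the flow the new test exercises: yaml -> Trainer -> one step;
# assumes the mock data folder referenced by the config already exists
from alphafold3_pytorch import create_trainer_from_yaml

trainer = create_trainer_from_yaml(
    './tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml'
)

trainer()  # a single training step, since the config sets num_train_steps: 1
```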
