
Commit 565577a

take care of data collation, fix some bugs due to inplace ops
1 parent 38a2a6c commit 565577a

5 files changed: +84 −14 lines changed

alphafold3_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -34,6 +34,7 @@
 
 from alphafold3_pytorch.trainer import (
     Trainer,
+    DataLoader,
    Alphafold3Input
 )
```
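With this export in place, the collation-aware `DataLoader` (defined in `trainer.py` below) is importable from the package root alongside the existing names:

```python
# the wrapped DataLoader now sits next to Trainer and Alphafold3Input
from alphafold3_pytorch import Trainer, DataLoader, Alphafold3Input
```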

alphafold3_pytorch/alphafold3.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -212,7 +212,7 @@ def repeat_consecutive_with_lens(
 
     output_indices = torch.zeros((batch, max_len + 1), device = device, dtype = torch.long)
 
-    indices.masked_fill_(~mask, max_len) # scatter to sink position for padding
+    indices = indices.masked_fill(~mask, max_len) # scatter to sink position for padding
     indices = rearrange(indices, 'b n w -> b (n w)')
 
     # scatter
```
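This hunk and the ones below swap in-place ops (`masked_fill_`, `+=`) for their out-of-place counterparts. The usual failure mode being avoided: mutating a tensor that autograd saved for the backward pass raises at `backward()` time, and mutating a caller's tensor leaks the change outside the function. A minimal standalone sketch of the autograd error, not the repo's code:

```python
import torch

# sigmoid saves its output for backward; mutating that output in-place
# invalidates the saved tensor, so backward() raises a RuntimeError
x = torch.randn(4, requires_grad = True)
y = x.sigmoid()
y.masked_fill_(torch.tensor([True, False, False, False]), 0.)

try:
    y.sum().backward()
except RuntimeError as e:
    print('in-place op broke autograd:', e)

# the out-of-place variant allocates a new tensor and keeps the graph intact
y = x.sigmoid().masked_fill(torch.tensor([True, False, False, False]), 0.)
y.sum().backward()  # fine
```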
```diff
@@ -3062,6 +3062,15 @@ def forward(
 
        atom_seq_len = atom_inputs.shape[-2]
 
+        # soft validate
+
+        valid_atom_len_mask = residue_atom_lens >= 0
+
+        residue_atom_lens = residue_atom_lens.masked_fill(~valid_atom_len_mask, 0)
+        residue_atom_indices = residue_atom_indices.masked_fill(~valid_atom_len_mask, 0)
+
+        assert (residue_atom_indices < residue_atom_lens)[valid_atom_len_mask].all(), 'residue_atom_indices cannot have an index that exceeds the length of the atoms for that residue as given by residue_atom_lens'
+
        assert exists(residue_atom_lens) or exists(atom_mask)
 
        # if atompair inputs are not windowed, window it
```
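The "soft validate" block exists because the new collation function pads integer fields with `-1`: the mask records which entries are real, the sentinels are zeroed so downstream indexing stays in range, and the assert only checks real entries. The same three lines on a toy tensor:

```python
import torch

# a padded batch row: -1 marks entries added by collation
residue_atom_lens = torch.tensor([[2, 3, -1, -1]])

valid_atom_len_mask = residue_atom_lens >= 0
# -> tensor([[ True,  True, False, False]])

residue_atom_lens = residue_atom_lens.masked_fill(~valid_atom_len_mask, 0)
# -> tensor([[2, 3, 0, 0]])
```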
```diff
@@ -3079,7 +3088,7 @@ def forward(
        # handle offsets for residue atom indices
 
        if exists(residue_atom_indices):
-            residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
+            residue_atom_indices = residue_atom_indices + F.pad(residue_atom_lens, (-1, 1), value = 0)
 
        # get atom sequence length and residue sequence length depending on whether using packed atomic seq
 
```
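Only the in-place `+=` changes here; the `F.pad(residue_atom_lens, (-1, 1))` idiom is untouched. Since negative pad amounts in `F.pad` trim rather than pad, this drops the first length and appends a zero, which is worth seeing once in isolation:

```python
import torch
import torch.nn.functional as F

residue_atom_lens = torch.tensor([3, 2, 4])

# the negative left pad trims one element; the right pad of 1 appends value 0
print(F.pad(residue_atom_lens, (-1, 1), value = 0))  # tensor([2, 4, 0])
```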
```diff
@@ -3118,7 +3127,7 @@ def forward(
 
            token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
            diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
-            token_bond.masked_fill_(diagonal, False)
+            token_bond = token_bond.masked_fill(diagonal, False)
        else:
            seq_arange = torch.arange(seq_len, device = self.device)
            token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
```
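For reference, the `else` branch's `einx.subtract` line builds the default bond matrix, marking tokens at sequence distance one as bonded. A quick standalone check (assuming `einx` is installed, as the repo requires):

```python
import torch
import einx

seq_arange = torch.arange(4)

# |i - j| == 1 marks each token as bonded to its immediate neighbors
token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
print(token_bond.int())
# tensor([[0, 1, 0, 0],
#         [1, 0, 1, 0],
#         [0, 1, 0, 1],
#         [0, 0, 1, 0]], dtype=torch.int32)
```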

alphafold3_pytorch/trainer.py

Lines changed: 62 additions & 3 deletions
```diff
@@ -3,16 +3,19 @@
 from pathlib import Path
 
 from alphafold3_pytorch.alphafold3 import Alphafold3
+from alphafold3_pytorch.attention import pad_at_dim
 
-from typing import TypedDict
+from typing import TypedDict, List
 from alphafold3_pytorch.typing import (
    typecheck,
    Int, Bool, Float
 )
 
 import torch
+from torch import Tensor
 from torch.optim import Adam, Optimizer
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader as OrigDataLoader
+from torch.nn.utils.rnn import pad_sequence
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
 from ema_pytorch import EMA
```
```diff
@@ -24,7 +27,7 @@
 @typecheck
 class Alphafold3Input(TypedDict):
    atom_inputs: Float['m dai']
-    residue_atom_lens: Int['n 2']
+    residue_atom_lens: Int[' n']
    atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']
    additional_residue_feats: Float['n 10']
    templates: Float['t n n dt']
```
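The `residue_atom_lens` annotation changes from `Int['n 2']` to `Int[' n']`: one atom count per residue. A single un-batched example consistent with the shapes in this hunk (feature widths taken from the test file below; TypedDict fields not visible in this hunk are omitted, so this is a shape illustration rather than a complete `Alphafold3Input`):

```python
import torch

n, m = 4, 16  # n residues, m atoms

single_input = dict(
    atom_inputs = torch.randn(m, 77),               # Float['m dai']
    residue_atom_lens = torch.randint(1, 4, (n,)),  # Int[' n']
    atompair_inputs = torch.randn(m, m, 5),         # Float['m m dapi']
    additional_residue_feats = torch.randn(n, 10),  # Float['n 10']
    templates = torch.randn(2, n, n, 44),           # Float['t n n dt']
)
```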
```diff
@@ -70,6 +73,62 @@ def accum_dict(
 
    return past_losses
 
+# dataloader and collation fn
+
+@typecheck
+def collate_af3_inputs(
+    inputs: List[Alphafold3Input],
+    int_pad_value = -1
+):
+    # separate input dictionary into keys and values
+
+    keys = inputs[0].keys()
+    inputs = [i.values() for i in inputs]
+
+    outputs = []
+
+    for grouped in zip(*inputs):
+        # if all None, just return None
+
+        if not any([*map(exists, grouped)]):
+            outputs.append(None)
+            continue
+
+        # use -1 for padding int values, for assuming int are labels - if not, handle within alphafold3
+
+        pad_value = int_pad_value if grouped[0].dtype in (torch.int, torch.long) else 0
+
+        # get the max lengths across all dimensions
+
+        shapes_as_tensor = torch.stack([Tensor(tuple(g.shape)) for g in grouped], dim = -1)
+
+        max_lengths = shapes_as_tensor.int().amax(dim = -1)
+
+        # pad across all dimensions
+
+        padded_inputs = []
+
+        for inp in grouped:
+            for dim, max_length in enumerate(max_lengths.tolist()):
+                inp = pad_at_dim(inp, (0, max_length - inp.shape[dim]), value = pad_value, dim = dim)
+
+            padded_inputs.append(inp)
+
+        # stack
+
+        stacked = torch.stack(padded_inputs)
+
+        outputs.append(stacked)
+
+    # reconstitute dictionary
+
+    return dict(tuple(zip(keys, outputs)))
+
+def DataLoader(*args, **kwargs):
+    return OrigDataLoader(*args, collate_fn = collate_af3_inputs, **kwargs)
+
+# default scheduler used in paper w/ warmup
+
 def default_lambda_lr_fn(steps):
    # 1000 step warmup
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.5"
+version = "0.1.6"
 description = "Alphafold 3 - Pytorch"
 authors = [
    { name = "Phil Wang", email = "[email protected]" }
```

tests/test_trainer.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -2,14 +2,16 @@
 os.environ['TYPECHECK'] = 'True'
 
 from pathlib import Path
+from random import randrange
 
 import pytest
 import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 
 from alphafold3_pytorch import (
    Alphafold3,
    Alphafold3Input,
+    DataLoader,
    Trainer
 )
```

```diff
@@ -19,25 +21,24 @@ class MockAtomDataset(Dataset):
    def __init__(
        self,
        data_length,
-        seq_len = 16,
+        max_seq_len = 16,
        atoms_per_window = 4
    ):
        self.data_length = data_length
-        self.seq_len = seq_len
+        self.max_seq_len = max_seq_len
        self.atoms_per_window = atoms_per_window
-        self.atom_seq_len = seq_len * atoms_per_window
 
    def __len__(self):
        return self.data_length
 
    def __getitem__(self, idx):
-        seq_len = self.seq_len
-        atom_seq_len = self.atom_seq_len
+        seq_len = randrange(1, self.max_seq_len)
+        atom_seq_len = self.atoms_per_window * seq_len
 
        atom_inputs = torch.randn(atom_seq_len, 77)
        atompair_inputs = torch.randn(atom_seq_len, atom_seq_len, 5)
 
-        residue_atom_lens = torch.randint(0, self.atoms_per_window, (seq_len,))
+        residue_atom_lens = torch.randint(1, self.atoms_per_window, (seq_len,))
        additional_residue_feats = torch.randn(seq_len, 10)
 
        templates = torch.randn(2, seq_len, seq_len, 44)
```
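Put together, the mock dataset now yields variably sized examples, and the exported `DataLoader` pads and stacks them into one batch. A hypothetical smoke test of the pipeline (assuming `__getitem__` returns an `Alphafold3Input` built from the tensors above):

```python
dataset = MockAtomDataset(data_length = 4)
dataloader = DataLoader(dataset, batch_size = 4)

batch = next(iter(dataloader))

# every field is padded to the largest example in the batch;
# integer fields carry -1 wherever padding was added
print(batch['atom_inputs'].shape)  # (4, longest atom_seq_len in batch, 77)
```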
