
Commit 5591994

move some input-related functions from trainer.py to inputs.py, and allow passing an Alphafold3Input directly into Alphafold3, simplifying some CLI and app logic

1 parent fab65d6
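
In practice, this change collapses the explicit collate-then-forward sequence into a single call. A minimal before/after sketch (not taken verbatim from the diff), assuming `alphafold3` is an already configured `Alphafold3` instance whose constructor arguments are elided here:

```python
from alphafold3_pytorch import Alphafold3Input, alphafold3_inputs_to_batched_atom_input

af3_input = Alphafold3Input(proteins = ['AG'])

# before: collate manually, then call forward
batched = alphafold3_inputs_to_batched_atom_input(af3_input, atoms_per_window = 27)
loss = alphafold3(**batched.model_forward_dict())

# after: pass the input (or a list of inputs) directly
loss = alphafold3.forward_with_alphafold3_inputs(af3_input)
```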

File tree

11 files changed: +202 −208 lines

README.md

Lines changed: 3 additions & 12 deletions
@@ -170,12 +170,7 @@ An example with molecule level input handling
 
 ```python
 import torch
-
-from alphafold3_pytorch import (
-    Alphafold3,
-    Alphafold3Input,
-    alphafold3_inputs_to_batched_atom_input
-)
+from alphafold3_pytorch import Alphafold3, Alphafold3Input
 
 contrived_protein = 'AG'
 
@@ -193,8 +188,6 @@ eval_alphafold3_input = Alphafold3Input(
     proteins = [contrived_protein]
 )
 
-batched_atom_input = alphafold3_inputs_to_batched_atom_input(train_alphafold3_input, atoms_per_window = 27)
-
 # training
 
 alphafold3 = Alphafold3(
@@ -222,15 +215,13 @@ alphafold3 = Alphafold3(
     )
 )
 
-loss = alphafold3(**batched_atom_input.model_forward_dict())
+loss = alphafold3.forward_with_alphafold3_inputs([train_alphafold3_input])
 loss.backward()
 
 # sampling
 
-batched_eval_atom_input = alphafold3_inputs_to_batched_atom_input(eval_alphafold3_input, atoms_per_window = 27)
-
 alphafold3.eval()
-sampled_atom_pos = alphafold3(**batched_eval_atom_input.model_forward_dict())
+sampled_atom_pos = alphafold3.forward_with_alphafold3_inputs(eval_alphafold3_input)
 
 assert sampled_atom_pos.shape == (1, (5 + 4), 3)
 ```
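
(The `(5 + 4)` in the asserted shape presumably counts the heavy atoms of the contrived `'AG'` protein, 5 for alanine and 4 for glycine, with a leading batch dimension of 1.)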

alphafold3_pytorch/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -50,14 +50,14 @@
     PDBDataset,
     maybe_transform_to_atom_input,
     maybe_transform_to_atom_inputs,
+    alphafold3_inputs_to_batched_atom_input,
+    collate_inputs_to_batched_atom_input,
+    pdb_inputs_to_batched_atom_input,
 )
 
 from alphafold3_pytorch.trainer import (
     Trainer,
     DataLoader,
-    collate_inputs_to_batched_atom_input,
-    alphafold3_inputs_to_batched_atom_input,
-    pdb_inputs_to_batched_atom_input,
 )
 
 from alphafold3_pytorch.configs import (

alphafold3_pytorch/alphafold3.py

Lines changed: 15 additions & 1 deletion
@@ -67,8 +67,10 @@
     NUM_MSA_ONE_HOT,
     DEFAULT_NUM_MOLECULE_MODS,
     ADDITIONAL_MOLECULE_FEATS,
+    hard_validate_atom_indices_ascending,
     BatchedAtomInput,
-    hard_validate_atom_indices_ascending
+    Alphafold3Input,
+    alphafold3_inputs_to_batched_atom_input,
 )
 
 from alphafold3_pytorch.common.biomolecule import (
@@ -6345,6 +6347,18 @@ def shrink_and_perturb_(
 
         return self
 
+    @typecheck
+    def forward_with_alphafold3_inputs(
+        self,
+        alphafold3_inputs: Alphafold3Input | list[Alphafold3Input],
+        **kwargs
+    ):
+        if not isinstance(alphafold3_inputs, list):
+            alphafold3_inputs = [alphafold3_inputs]
+
+        batched_atom_inputs = alphafold3_inputs_to_batched_atom_input(alphafold3_inputs, atoms_per_window = self.w)
+        return self.forward(**batched_atom_inputs.model_forward_dict(), **kwargs)
+
     @typecheck
     def forward(
         self,
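
The new method wraps a lone `Alphafold3Input` into a singleton list, collates it with the model's own atom window size (`self.w`), and passes any extra keyword arguments through to `forward`. A short inference sketch, reusing the README's contrived `'AG'` protein and assuming `model` is an already loaded `Alphafold3` instance:

```python
from alphafold3_pytorch import Alphafold3Input

af3_input = Alphafold3Input(proteins = ['AG'])

model.eval()

# extra keyword arguments pass straight through to forward(),
# e.g. returning Biopython structures as the cli change below does
structure, = model.forward_with_alphafold3_inputs(
    af3_input,
    return_bio_pdb_structures = True
)
```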

alphafold3_pytorch/app.py

Lines changed: 1 addition & 3 deletions
@@ -20,10 +20,8 @@ def fold(protein):
         proteins = [protein]
     )
 
-    batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input, atoms_per_window = model.atoms_per_window)
-
     model.eval()
-    atom_pos, = model(**batched_atom_input.model_forward_dict())
+    atom_pos, = model.forward_with_alphafold3_inputs(alphafold3_input)
 
     return str(atom_pos.tolist())
 
alphafold3_pytorch/cli.py

Lines changed: 2 additions & 5 deletions
@@ -7,8 +7,7 @@
 
 from alphafold3_pytorch import (
     Alphafold3,
-    Alphafold3Input,
-    alphafold3_inputs_to_batched_atom_input
+    Alphafold3Input
 )
 
 from Bio.PDB.mmcifio import MMCIFIO
@@ -40,10 +39,8 @@ def cli(
 
     alphafold3 = Alphafold3.init_and_load(checkpoint_path)
 
-    batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input, atoms_per_window = alphafold3.atoms_per_window)
-
     alphafold3.eval()
-    structure, = alphafold3(**batched_atom_input.model_forward_dict(), return_bio_pdb_structures = True)
+    structure, = alphafold3.forward_with_alphafold3_inputs(alphafold3_input, return_bio_pdb_structures = True)
 
     output_path = Path(output)
     output_path.parents[0].mkdir(exist_ok = True, parents = True)

alphafold3_pytorch/inputs.py

Lines changed: 170 additions & 1 deletion
@@ -41,7 +41,11 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 
-from alphafold3_pytorch.common import amino_acid_constants, dna_constants, rna_constants
+from alphafold3_pytorch.common import (
+    amino_acid_constants,
+    dna_constants,
+    rna_constants
+)
 from alphafold3_pytorch.common.biomolecule import (
     Biomolecule,
     _from_mmcif_object,
@@ -86,6 +90,13 @@
 from alphafold3_pytorch.tensor_typing import Bool, Float, Int, typecheck
 from alphafold3_pytorch.utils.utils import default, exists, first, not_exists
 
+from alphafold3_pytorch.attention import (
+    full_pairwise_repr_to_windowed,
+    full_attn_bias_to_windowed,
+    pad_at_dim,
+    pad_or_slice_to
+)
+
 # silence RDKit's warnings
 
 RDLogger.DisableLog("rdApp.*")
@@ -1821,6 +1832,17 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
 
     return molecule_input
 
+@typecheck
+def alphafold3_inputs_to_batched_atom_input(
+    inp: Alphafold3Input | List[Alphafold3Input],
+    **collate_kwargs
+) -> BatchedAtomInput:
+
+    if isinstance(inp, Alphafold3Input):
+        inp = [inp]
+
+    atom_inputs = maybe_transform_to_atom_inputs(inp)
+    return collate_inputs_to_batched_atom_input(atom_inputs, **collate_kwargs)
 
 # pdb input
 
@@ -3366,6 +3388,17 @@ def pdb_input_to_molecule_input(
 
     return molecule_input
 
+@typecheck
+def pdb_inputs_to_batched_atom_input(
+    inp: PDBInput | List[PDBInput],
+    **collate_kwargs
+) -> BatchedAtomInput:
+
+    if isinstance(inp, PDBInput):
+        inp = [inp]
+
+    atom_inputs = maybe_transform_to_atom_inputs(inp)
+    return collate_inputs_to_batched_atom_input(atom_inputs, **collate_kwargs)
 
 # datasets
 
@@ -3517,6 +3550,142 @@ def __getitem__(self, idx: int | str, max_attempts: int = 10) -> PDBInput | AtomInput:
 
         return i
 
+# collation function
+
+@typecheck
+def collate_inputs_to_batched_atom_input(
+    inputs: List,
+    int_pad_value = -1,
+    atoms_per_window: int | None = None,
+    map_input_fn: Callable | None = None,
+    transform_to_atom_inputs: bool = True,
+) -> BatchedAtomInput:
+
+    if exists(map_input_fn):
+        inputs = [map_input_fn(i) for i in inputs]
+
+    # go through all the inputs
+    # and for any that is not AtomInput, try to transform it with the registered input type to corresponding registered function
+
+    if transform_to_atom_inputs:
+        atom_inputs = maybe_transform_to_atom_inputs(inputs)
+
+        if len(atom_inputs) < len(inputs):
+            # if some of the `inputs` could not be converted into `atom_inputs`,
+            # randomly select a subset of the `atom_inputs` to duplicate to match
+            # the expected number of `atom_inputs`
+            assert (
+                len(atom_inputs) > 0
+            ), "No `AtomInput` objects could be created for the current batch."
+            atom_inputs = random.choices(atom_inputs, k=len(inputs))  # nosec
+    else:
+        assert all(isinstance(i, AtomInput) for i in inputs), (
+            "When `transform_to_atom_inputs=False`, all provided "
+            "inputs must be of type `AtomInput`."
+        )
+        atom_inputs = inputs
+
+    assert all(isinstance(i, AtomInput) for i in atom_inputs), (
+        "All inputs must be of type `AtomInput`. "
+        "If you want to transform the inputs to `AtomInput`, "
+        "set `transform_to_atom_inputs=True`."
+    )
+
+    # take care of windowing the atompair_inputs and atompair_ids if they are not windowed already
+
+    if exists(atoms_per_window):
+        for atom_input in atom_inputs:
+            atompair_inputs = atom_input.atompair_inputs
+            atompair_ids = atom_input.atompair_ids
+
+            atompair_inputs_is_windowed = atompair_inputs.ndim == 4
+
+            if not atompair_inputs_is_windowed:
+                atom_input.atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window)
+
+            if exists(atompair_ids):
+                atompair_ids_is_windowed = atompair_ids.ndim == 3
+
+                if not atompair_ids_is_windowed:
+                    atom_input.atompair_ids = full_attn_bias_to_windowed(atompair_ids, window_size = atoms_per_window)
+
+    # separate input dictionary into keys and values
+
+    keys = list(atom_inputs[0].dict().keys())
+    atom_inputs = [i.dict().values() for i in atom_inputs]
+
+    outputs = []
+
+    for key, grouped in zip(keys, zip(*atom_inputs)):
+        # if all None, just return None
+
+        not_none_grouped = [*filter(exists, grouped)]
+
+        if len(not_none_grouped) == 0:
+            outputs.append(None)
+            continue
+
+        # collate lists for uncollatable fields
+
+        if key in UNCOLLATABLE_ATOM_INPUT_FIELDS:
+            outputs.append(grouped)
+            continue
+
+        # default to empty tensor for any Nones
+
+        one_tensor = not_none_grouped[0]
+
+        dtype = one_tensor.dtype
+        ndim = one_tensor.ndim
+
+        # use -1 for padding int values, for assuming int are labels - if not, handle within alphafold3
+
+        if key in ATOM_DEFAULT_PAD_VALUES:
+            pad_value = ATOM_DEFAULT_PAD_VALUES[key]
+        elif dtype in (torch.int, torch.long):
+            pad_value = int_pad_value
+        elif dtype == torch.bool:
+            pad_value = False
+        else:
+            pad_value = 0.
+
+        # get the max lengths across all dimensions
+
+        shapes_as_tensor = torch.stack([tensor(tuple(g.shape) if exists(g) else ((0,) * ndim)).int() for g in grouped], dim = -1)
+
+        max_lengths = shapes_as_tensor.amax(dim = -1)
+
+        default_tensor = torch.full(max_lengths.tolist(), pad_value, dtype = dtype)
+
+        # pad across all dimensions
+
+        padded_inputs = []
+
+        for inp in grouped:
+
+            if not exists(inp):
+                padded_inputs.append(default_tensor)
+                continue
+
+            for dim, max_length in enumerate(max_lengths.tolist()):
+                inp = pad_at_dim(inp, (0, max_length - inp.shape[dim]), value = pad_value, dim = dim)
+
+            padded_inputs.append(inp)
+
+        # stack
+
+        stacked = torch.stack(padded_inputs)
+
+        outputs.append(stacked)
+
+    # batched atom input dictionary
+
+    batched_atom_input_dict = dict(tuple(zip(keys, outputs)))
+
+    # reconstitute dictionary
+
+    batched_atom_inputs = BatchedAtomInput(**batched_atom_input_dict)
+    return batched_atom_inputs
 
 # the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
 # this can be preprocessed or will be taken care of automatically within the Trainer during data collation
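
The new collate function pads every tensor field up to the per-dimension maximum across the batch before stacking, choosing the pad value by field name (`ATOM_DEFAULT_PAD_VALUES`) or dtype. A minimal standalone sketch of that padding scheme in plain torch (the `pad_to_shape` helper is illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F

def pad_to_shape(t: torch.Tensor, shape: list, value = 0.) -> torch.Tensor:
    # right-pad each dimension of `t` up to `shape`, mirroring the
    # per-dimension `pad_at_dim` loop in collate_inputs_to_batched_atom_input
    pad = []
    for dim in reversed(range(t.ndim)):
        pad += [0, shape[dim] - t.shape[dim]]
    return F.pad(t, pad, value = value)

# two inputs with differing atom counts, as in a batch of two AtomInputs
a = torch.randn(5, 3)
b = torch.randn(9, 3)

max_shape = [max(sizes) for sizes in zip(a.shape, b.shape)]
batched = torch.stack([pad_to_shape(t, max_shape) for t in (a, b)])

assert batched.shape == (2, 9, 3)
```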

0 commit comments