Commit 0a8bae5

committed
address a few issues and fix molecule_ids not accounting for ligands having one token per atom
1 parent 08d4c55 commit 0a8bae5

File tree

7 files changed: +21 −10 lines


README.md

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ A fork with full Lightning + Hydra support is being maintained by <a href="https
 
 - <a href="https://github.com/milot-mirdita">Milot</a> for optimizing the PDB dataset clustering script!
 
+- <a href="https://github.com/amorehead">Alex</a> for basically writing the entire gargantuan flow from parsing the PDB all the way to the molecule and atomic inputs for training
+
 - <a href="https://github.com/patrick-kidger">Patrick</a> for <a href="https://docs.kidger.site/jaxtyping/">jaxtyping</a>, <a href="https://github.com/fferflo">Florian</a> for <a href="https://github.com/fferflo/einx">einx</a>, and of course, <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://einops.rocks/">einops</a>
 
 ## Install

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 2 deletions
@@ -3381,7 +3381,7 @@ def forward(
         return_present_sampled_atoms: bool = False,
         num_rollout_steps: int = 20,
         rollout_show_tqdm_pbar: bool = False
-    ) -> Float['b m 3'] | Float['l 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
+    ) -> Float['b m 3'] | List[Float['l 3']] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
         atom_seq_len = atom_inputs.shape[-2]
 

@@ -3624,8 +3624,9 @@ def forward(
 
         if exists(atom_mask):
             sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
+
         if exists(missing_atom_mask) and return_present_sampled_atoms:
-            sampled_atom_pos = sampled_atom_pos[~missing_atom_mask]
+            sampled_atom_pos = [one_sampled_atom_pos[~one_missing_atom_mask] for one_sampled_atom_pos, one_missing_atom_mask in zip(sampled_atom_pos, missing_atom_mask)]
 
         return sampled_atom_pos
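Why the return annotation changes from a single tensor to a list: boolean-indexing a batched coordinate tensor with a batched mask collapses the batch dimension into one flat tensor, so dropping missing atoms per sample requires ragged per-sample tensors. A minimal sketch with toy shapes (the sizes and values below are illustrative, not from the repository):

```python
import torch

# Toy shapes only - illustrates why the commit switches from tensor indexing
# to a per-sample list. Boolean-indexing a (b, m, 3) tensor with a (b, m)
# mask flattens every sample into one (l, 3) tensor, losing batch boundaries.
sampled_atom_pos = torch.randn(2, 5, 3)            # (batch, atoms, xyz)
missing_atom_mask = torch.tensor([
    [False, False, True, True, True],              # sample 0 keeps 2 atoms
    [False, True, True, True, True],               # sample 1 keeps 1 atom
])

flattened = sampled_atom_pos[~missing_atom_mask]   # shape (3, 3): samples merged

# Indexing each sample separately preserves boundaries, one ragged (l, 3)
# tensor per batch element - hence the List[Float['l 3']] annotation
per_sample = [pos[~mask] for pos, mask in zip(sampled_atom_pos, missing_atom_mask)]
assert [t.shape[0] for t in per_sample] == [2, 1]
```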

alphafold3_pytorch/inputs.py

Lines changed: 2 additions & 3 deletions
@@ -707,7 +707,8 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
        (mol_from_smile(ligand) if isinstance(ligand, str) else ligand) for ligand in ligands
    ]
 
-    molecule_ids.append(tensor([ligand_id] * len(mol_ligands)))
+    for mol_ligand in mol_ligands:
+        molecule_ids.append(tensor([ligand_id] * mol_ligand.GetNumAtoms()))
 
    # create the molecule input
 

@@ -819,8 +820,6 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
    # handle molecule ids
 
    molecule_ids = torch.cat(molecule_ids).long()
-    # TODO: do not pad this with zeros anymore, as it will mistakenly treat padded tokens as `ALA`
-    molecule_ids = pad_to_len(molecule_ids, num_tokens)
 
    # handle atom_parent_ids
    # this governs in the atom encoder / decoder, which atom attends to which
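The fix named in the commit message lives in the first hunk: ligands are tokenized with one token per atom, so appending a single id per ligand under-fills `molecule_ids`; the loop repeats the id once per atom. The second hunk drops the zero-padding, resolving the removed TODO: id 0 denotes `ALA`, so padded positions would have been mistaken for alanine tokens. A small illustration of the per-atom counting, using RDKit directly (the SMILES strings and `ligand_id` value are made up for the example):

```python
from rdkit import Chem
from torch import tensor

# Hypothetical inputs for illustration - each ligand atom is its own token,
# so the id must be repeated once per atom, not once per ligand
ligand_id = 4
mol_ligands = [Chem.MolFromSmiles('CCO'), Chem.MolFromSmiles('c1ccccc1')]

molecule_ids = []
for mol_ligand in mol_ligands:
    # GetNumAtoms() counts heavy atoms: 3 for ethanol, 6 for benzene
    molecule_ids.append(tensor([ligand_id] * mol_ligand.GetNumAtoms()))

assert [len(ids) for ids in molecule_ids] == [3, 6]
```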

alphafold3_pytorch/utils/model_utils.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,5 @@
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Any
 
 import einx
 import torch

@@ -18,7 +18,6 @@
 
 # default scheduler used in paper w/ warmup
 
-
 def default_lambda_lr_fn(steps: int) -> float:
     """Default lambda learning rate function.
 
alphafold3_pytorch/utils/utils.py

Lines changed: 10 additions & 1 deletion
@@ -1,6 +1,15 @@
 import numpy as np
 
-from typing import Any
+from typing import Any, List
+
+def first(arr: List) -> Any:
+    """
+    Returns first element of list
+
+    :param arr: the list
+    :return: the element
+    """
+    return arr[0]
 
 
 def exists(val: Any) -> bool:
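A quick usage note for the new helper, presumably a convenience for the list-valued returns introduced elsewhere in this commit (the import path is inferred from the file location above; the list values are made up):

```python
from alphafold3_pytorch.utils.utils import first

# `first` is a readable alias for head-of-list indexing
assert first([3, 1, 2]) == 3
```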

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.134"
+version = "0.1.135"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_input.py

Lines changed: 2 additions & 1 deletion
@@ -184,7 +184,8 @@ def test_pdbinput_input():
     batched_eval_atom_input = pdb_inputs_to_batched_atom_input(eval_pdb_input, atoms_per_window=27)
 
     alphafold3.eval()
-    sampled_atom_pos = alphafold3(
+
+    sampled_atom_pos, = alphafold3(
         **batched_eval_atom_input.dict(), return_loss=False, return_present_sampled_atoms=True
     )
 
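The trailing comma in the test does real work: with `return_present_sampled_atoms=True` the model now returns a list of per-sample tensors, and for a batch of one the comma unpacks the single element back into a bare tensor. A toy illustration (the shape is made up):

```python
import torch

# Stand-in for the model's new List[Float['l 3']] return value
returned = [torch.randn(10, 3)]

# Trailing-comma unpacking: equivalent to returned[0], but raises if the
# list does not contain exactly one element
sampled_atom_pos, = returned
assert sampled_atom_pos.shape == (10, 3)
```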
