Skip to content

Commit aa47e39

Browse files
committed
Let molecule_lengthed_molecule_input_to_atom_input construct one_token_per_atom automatically by detecting ligands and modified biomolecules
1 parent c8d90b1 commit aa47e39

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ class MoleculeLengthMoleculeInput:
691691
src_tgt_atom_indices: Int['n 2']
692692
token_bonds: Bool['n n'] | None = None
693693
one_token_per_atom: List[bool] | None = None
694-
is_molecule_mod: Bool['n num_mods'] | None = None
694+
is_molecule_mod: Bool['n num_mods'] | Bool['n'] | None = None
695695
molecule_atom_indices: List[int | None] | None = None
696696
distogram_atom_indices: List[int | None] | None = None
697697
missing_atom_indices: List[Int[' _'] | None] | None = None
@@ -724,11 +724,23 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
724724

725725
# derive `atom_lens` based on `one_token_per_atom`, for ligands and modified biomolecules
726726

727-
assert len(molecules) == len(i.one_token_per_atom)
728-
729727
atoms_per_molecule = tensor([mol.GetNumAtoms() for mol in molecules])
730728
ones = torch.ones_like(atoms_per_molecule)
731729

730+
# `is_molecule_mod` can either be
731+
# 1. Bool['n'], in which case it will only be used for determining `one_token_per_atom`, or
732+
# 2. Bool['n num_mods'], where it will be passed to Alphafold3 for molecule modification embeds
733+
734+
is_molecule_mod = i.is_molecule_mod
735+
is_molecule_any_mod = False
736+
737+
if exists(is_molecule_mod):
738+
if i.is_molecule_mod.ndim == 2:
739+
is_molecule_any_mod = is_molecule_mod.any(dim = -1)
740+
else:
741+
is_molecule_any_mod = is_molecule_mod
742+
is_molecule_mod = None
743+
732744
# get `one_token_per_atom`, which can be fully customizable
733745

734746
if exists(i.one_token_per_atom):
@@ -737,7 +749,9 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
737749
# if `one_token_per_atom` (which molecules use one token per atom) is not passed in
738750
# default to what the paper did, which is ligands and any modified biomolecule
739751
is_ligand = i.is_molecule_types[..., IS_LIGAND_INDEX]
740-
one_token_per_atom = is_ligand | is_molecule_mod.any(dim = -1)
752+
one_token_per_atom = is_ligand | is_molecule_any_mod
753+
754+
assert len(molecules) == len(one_token_per_atom)
741755

742756
# derive the number of repeats needed to expand molecule lengths to token lengths
743757

@@ -782,7 +796,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
782796
molecule_atom_indices = repeat_interleave(i.molecule_atom_indices, token_repeats)
783797

784798
msa = maybe(repeat_interleave)(i.msa, token_repeats, dim = -2)
785-
is_molecule_mod = maybe(repeat_interleave)(i.is_molecule_mod, token_repeats, dim = -2)
799+
is_molecule_mod = maybe(repeat_interleave)(i.is_molecule_mod, token_repeats, dim = 0)
786800

787801
templates = maybe(repeat_interleave)(i.templates, token_repeats, dim = -3)
788802
templates = maybe(repeat_interleave)(templates, token_repeats, dim = -2)
@@ -1340,12 +1354,6 @@ def alphafold3_input_to_molecule_lengthed_molecule_input(alphafold3_input: Alpha
13401354
*mol_metal_ions
13411355
]
13421356

1343-
one_token_per_atom = [
1344-
*((False,) * len(molecules_without_ligands)),
1345-
*((True,) * len(mol_ligands)),
1346-
*((False,) * len(mol_metal_ions)),
1347-
]
1348-
13491357
for mol in molecules:
13501358
Chem.SanitizeMol(mol)
13511359

@@ -1498,7 +1506,6 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
14981506

14991507
molecule_input = MoleculeLengthMoleculeInput(
15001508
molecules=molecules,
1501-
one_token_per_atom=one_token_per_atom,
15021509
molecule_atom_indices=molecule_atom_indices,
15031510
distogram_atom_indices=distogram_atom_indices,
15041511
molecule_ids=molecule_ids,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.86"
3+
version = "0.2.88"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)