Skip to content

Commit 867ba97

Browse files
committed
after a big refactor, now modified biomolecules are expanded into one token per atom, and furthermore, any molecule can be expanded if one wishes to do so
1 parent a1ae7f3 commit 867ba97

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,12 +677,12 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
677677
@dataclass
678678
class MoleculeLengthMoleculeInput:
679679
molecules: List[Mol]
680-
one_token_per_atom: List[bool]
681680
molecule_ids: Int[' n']
682681
additional_molecule_feats: Int[f'n {ADDITIONAL_MOLECULE_FEATS-1}']
683682
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
684683
src_tgt_atom_indices: Int['n 2']
685684
token_bonds: Bool['n n']
685+
one_token_per_atom: List[bool] | None = None
686686
is_molecule_mod: Bool['n num_mods'] | None = None
687687
molecule_atom_indices: List[int | None] | None = None
688688
distogram_atom_indices: List[int | None] | None = None
@@ -719,7 +719,16 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
719719

720720
atoms_per_molecule = tensor([mol.GetNumAtoms() for mol in molecules])
721721
ones = torch.ones_like(atoms_per_molecule)
722-
one_token_per_atom = tensor(i.one_token_per_atom)
722+
723+
# get `one_token_per_atom`, which can be fully customizable
724+
725+
if exists(i.one_token_per_atom):
726+
one_token_per_atom = tensor(i.one_token_per_atom)
727+
else:
728+
# if which molecule is `one_token_per_atom` is not passed in
729+
# default to what the paper did, which is ligands and any modified biomolecule
730+
is_ligand = i.is_molecule_types[..., IS_LIGAND_INDEX]
731+
one_token_per_atom = is_ligand | is_molecule_mod.any(dim = -1)
723732

724733
# derive the number of repeats needed to expand molecule lengths to token lengths
725734

0 commit comments

Comments
 (0)