Skip to content

Commit df8b26c

Browse files
committed
wire up the derivation of the ligand frame from the atom positions
1 parent fd26b70 commit df8b26c

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,20 @@ def offset_only_positive(t, offset):
158158
t_offsetted = t + offset
159159
return torch.where(is_positive, t_offsetted, t)
160160

161+
@typecheck
def remove_consecutive_duplicate(
    t: Int['n ...'],
    remove_to_value = -1
) -> Int['n ...']:
    """Overwrite every entry along the first axis that repeats its predecessor.

    An entry `t[i]` (for `i >= 1`) is treated as a duplicate when it is
    elementwise equal to `t[i - 1]` across all trailing dimensions. Duplicated
    entries are replaced wholesale with `remove_to_value`; the first entry of
    any run is kept untouched.

    Args:
        t: integer tensor whose first axis is the one scanned for repeats.
        remove_to_value: fill value written over duplicated entries.

    Returns:
        A tensor shaped like `t` with consecutive duplicates filled in.
    """

    # nothing to compare - also avoids padding a zero-length mask up to length 1,
    # which would no longer line up with `t`
    if t.shape[0] == 0:
        return t

    is_duplicate = t[1:] == t[:-1]

    # an entry only counts as a duplicate if *all* of its trailing elements match,
    # regardless of how many trailing dimensions there are
    if is_duplicate.ndim > 1:
        is_duplicate = is_duplicate.flatten(start_dim = 1).all(dim = -1)

    # the first entry has no predecessor, so it can never be a duplicate
    is_duplicate = F.pad(is_duplicate, (1, 0), value = False)
    return einx.where('n, n ..., -> n ... ', ~is_duplicate, t, remove_to_value)
174+
161175
def compose(*fns: Callable):
162176
# for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
163177

@@ -1322,6 +1336,10 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
13221336
molecule_atom_indices = offset_only_positive(molecule_atom_indices, atom_indices_offsets)
13231337
atom_indices_for_frame = offset_only_positive(atom_indices_for_frame, atom_indices_offsets[..., None])
13241338

1339+
# hack: overwrite (with -1) any entry of the frame indices that repeats the entry right before it (ligands and modified biomolecules)
1340+
1341+
atom_indices_for_frame = remove_consecutive_duplicate(atom_indices_for_frame)
1342+
13251343
# handle atom positions
13261344

13271345
atom_pos = i.atom_pos
@@ -1451,6 +1469,13 @@ def alphafold3_input_to_molecule_lengthed_molecule_input(alphafold3_input: Alpha
14511469
ss_rnas = list(i.ss_rna)
14521470
ss_dnas = list(i.ss_dna)
14531471

1472+
# handle atom positions - need atom positions for deriving frame of ligand for PAE
1473+
1474+
atom_pos = i.atom_pos
1475+
1476+
if isinstance(atom_pos, list):
1477+
atom_pos = torch.cat(atom_pos)
1478+
14541479
# any double stranded nucleic acid is added to the single stranded lists together with its reverse complement
14551480
# rc stands for reverse complement
14561481

@@ -1568,12 +1593,31 @@ def alphafold3_input_to_molecule_lengthed_molecule_input(alphafold3_input: Alpha
15681593
# convert ligands to rdchem.Mol
15691594

15701595
ligands = list(alphafold3_input.ligands)
1596+
15711597
mol_ligands = [
15721598
(mol_from_smile(ligand) if isinstance(ligand, str) else ligand) for ligand in ligands
15731599
]
15741600

15751601
molecule_ids.append(tensor([ligand_id] * len(mol_ligands)))
1576-
1602+
1603+
# handle frames for the ligands, which depends on knowing the atom positions (section 4.3.2)
1604+
1605+
if exists(atom_pos):
1606+
ligand_atom_pos_offset = 0
1607+
1608+
for mol in flatten([*mol_proteins, *mol_ss_rnas, *mol_ss_dnas]):
1609+
ligand_atom_pos_offset += mol.GetNumAtoms()
1610+
1611+
for mol_ligand in mol_ligands:
1612+
num_ligand_atoms = mol_ligand.GetNumAtoms()
1613+
ligand_atom_pos = atom_pos[ligand_atom_pos_offset:(ligand_atom_pos_offset + num_ligand_atoms)]
1614+
1615+
frames = get_frames_from_atom_pos(ligand_atom_pos, filter_colinear_pos = True)
1616+
1617+
atom_indices_for_frame.append(frames.tolist())
1618+
1619+
ligand_atom_pos_offset += num_ligand_atoms
1620+
15771621
# convert metal ions to rdchem.Mol
15781622

15791623
metal_ions = alphafold3_input.metal_ions
@@ -1760,10 +1804,6 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
17601804
molecule_atom_indices = tensor(molecule_atom_indices)
17611805
molecule_atom_indices = pad_to_len(molecule_atom_indices, num_tokens, value=-1)
17621806

1763-
# atom positions
1764-
1765-
atom_pos = i.atom_pos
1766-
17671807
# handle missing atom indices
17681808

17691809
missing_atom_indices = None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.18"
3+
version = "0.4.19"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)