Skip to content

Commit 24ae4a2

Browse files
authored
Update inputs.py (#94)
1 parent 1a7c4d5 commit 24ae4a2

File tree

1 file changed

+120
-59
lines changed

1 file changed

+120
-59
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 120 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict
77
from dataclasses import asdict, dataclass, field
88
from functools import partial
9+
from itertools import groupby
910
from typing import Any, Callable, List, Set, Tuple, Type
1011

1112
import einx
@@ -1405,6 +1406,57 @@ def get_token_index_from_composite_atom_id(
14051406
return np.where(chain_mask & res_mask & atom_mask)[0][atom_index]
14061407

14071408

1409+
@typecheck
def find_mismatched_symmetry(
    asym_ids: np.ndarray, entity_ids: np.ndarray, sym_ids: np.ndarray, chemid: np.ndarray
) -> bool:
    """
    Find mismatched symmetry in a biomolecule's asymmetry, entity, symmetry, and token chemical IDs.

    For every entity ID, the tokens are split by symmetry ID. For each symmetry ID, the
    asymmetric (chain) IDs occurring in those tokens are collected, and the chemical IDs of
    all tokens belonging to those chains form that symmetry copy's chemical ID sequence.
    If any two symmetry copies of the same entity have chemical ID sequences that differ
    (in length or in content), the symmetry is "mismatched".

    NOTE: The inputs are assumed to be per-token arrays indexed along their first axis
    (only their lengths are checked here).

    :param asym_ids: An array of asymmetric unit (i.e., chain) IDs for each token in the biomolecule.
    :param entity_ids: An array of entity IDs for each token in the biomolecule.
    :param sym_ids: An array of symmetry IDs for each token in the biomolecule.
    :param chemid: An array of chemical IDs for each token in the biomolecule.
    :return: `True` if at least one entity has symmetry copies whose chemical ID sequences
        differ, `False` otherwise (including for empty inputs).
    :raises AssertionError: If the four input arrays do not have the same length.
    """
    assert len(asym_ids) == len(entity_ids) == len(sym_ids) == len(chemid), (
        f"The number of asymmetric unit IDs ({len(asym_ids)}), entity IDs ({len(entity_ids)}), symmetry IDs ({len(sym_ids)}), and chemical IDs ({len(chemid)}) do not match. "
        "Please ensure that these input features are correctly paired."
    )

    for entity in np.unique(entity_ids):
        entity_mask = entity_ids == entity

        # Chemical ID sequence of the first symmetry copy seen for this entity.
        # Equality is transitive, so checking every other copy against this one
        # yields the same answer as comparing all pairs of copies.
        reference_chemid_seq = None

        for sym in np.unique(sym_ids[entity_mask]):
            # Chains (asym IDs) that make up this symmetry copy of the entity
            copy_asym_ids = np.unique(asym_ids[entity_mask & (sym_ids == sym)])

            # Chemical IDs of every token belonging to those chains
            chemid_seq = chemid[np.isin(asym_ids, copy_asym_ids)]

            if reference_chemid_seq is None:
                reference_chemid_seq = chemid_seq
                continue

            # `np.array_equal` is already `False` when the shapes (lengths) differ
            if not np.array_equal(reference_chemid_seq, chemid_seq):
                return True

    return False
1458+
1459+
14081460
@typecheck
14091461
def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
14101462
"""Convert a PDBInput to a MoleculeInput."""
@@ -1532,6 +1584,64 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
15321584
molecule_atom_indices = tensor(molecule_atom_indices)
15331585
distogram_atom_indices = tensor(distogram_atom_indices)
15341586

1587+
# constructing the additional_molecule_feats
1588+
# which is in turn used to derive relative positions
1589+
1590+
# residue_index - an arange that restarts at 1 for each chain - reuse biomol.residue_index here
1591+
# token_index - just an arange
1592+
# asym_id - unique id for each chain of a biomolecule - reuse chain_index here
1593+
# entity_id - unique id for each biomolecule sequence
1594+
# sym_id - unique id for each chain of the same biomolecule sequence
1595+
1596+
# entity ids
1597+
1598+
unrepeated_entity_sequences = defaultdict(int)
1599+
for entity_sequence in chain_seqs:
1600+
if entity_sequence in unrepeated_entity_sequences:
1601+
continue
1602+
unrepeated_entity_sequences[entity_sequence] = len(unrepeated_entity_sequences)
1603+
1604+
entity_idx = 0
1605+
entity_id_counts = []
1606+
unrepeated_entity_ids = []
1607+
for entity_sequence, chain_chem_type in zip(chain_seqs, chain_chem_types):
1608+
entity_mol = molecules[entity_idx]
1609+
entity_len = (
1610+
entity_mol.GetNumAtoms() if chain_chem_type == "ligand" else len(entity_sequence)
1611+
)
1612+
entity_idx += 1 if chain_chem_type == "ligand" else len(entity_sequence)
1613+
1614+
entity_id_counts.append(entity_len)
1615+
unrepeated_entity_ids.append(unrepeated_entity_sequences[entity_sequence])
1616+
1617+
entity_ids = torch.repeat_interleave(tensor(unrepeated_entity_ids), tensor(entity_id_counts))
1618+
1619+
# sym ids
1620+
1621+
unrepeated_sym_ids = []
1622+
unrepeated_sym_sequences = defaultdict(int)
1623+
for entity_sequence in chain_seqs:
1624+
unrepeated_sym_ids.append(unrepeated_sym_sequences[entity_sequence])
1625+
if entity_sequence in unrepeated_sym_sequences:
1626+
unrepeated_sym_sequences[entity_sequence] += 1
1627+
unrepeated_sym_ids = tensor(unrepeated_sym_ids)
1628+
1629+
sym_ids = torch.repeat_interleave(tensor(unrepeated_sym_ids), tensor(entity_id_counts))
1630+
1631+
# concat for all of additional_molecule_feats
1632+
1633+
additional_molecule_feats = torch.stack(
1634+
(
1635+
# NOTE: `Biomolecule.residue_index` is 1-based originally
1636+
torch.from_numpy(biomol.residue_index) - 1,
1637+
torch.arange(num_tokens),
1638+
torch.from_numpy(biomol.chain_index),
1639+
entity_ids,
1640+
sym_ids,
1641+
),
1642+
dim=-1,
1643+
)
1644+
15351645
# construct token bonds, which will be linearly connected for proteins
15361646
# and nucleic acids, but for ligands will have their atomic bond matrix
15371647
# (as ligands are atom resolution)
@@ -1584,12 +1694,21 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
15841694
polymer_offset += chain_len
15851695
ligand_offset += chain_len
15861696

1697+
# ascertain whether homomeric (e.g., bonded ligand) symmetry is preserved,
1698+
# which determines whether or not we use the mmCIF bond inputs (AF3 Section 5.1)
1699+
lacking_homomeric_symmetry = find_mismatched_symmetry(
1700+
biomol.chain_index,
1701+
entity_ids.numpy(),
1702+
sym_ids.numpy(),
1703+
biomol.chemid,
1704+
)
1705+
15871706
# ensure mmCIF polymer-ligand (i.e., protein/RNA/DNA-ligand) and ligand-ligand bonds
15881707
# (and bonds less than 2.4 Å) are installed in `MoleculeInput` during training only
15891708
# per the AF3 supplement (Table 5, `token_bonds`)
15901709
bond_atom_indices = defaultdict(int)
15911710
for bond in biomol.bonds:
1592-
if not i.training:
1711+
if not i.training or lacking_homomeric_symmetry:
15931712
continue
15941713

15951714
# determine bond type
@@ -1657,64 +1776,6 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
16571776
bond_atom_indices[ptnr1_atom_id] += 1
16581777
bond_atom_indices[ptnr2_atom_id] += 1
16591778

1660-
# constructing the additional_molecule_feats
1661-
# which is in turn used to derive relative positions
1662-
1663-
# residue_index - an arange that restarts at 1 for each chain - reuse biomol.residue_index here
1664-
# token_index - just an arange
1665-
# asym_id - unique id for each chain of a biomolecule - reuse chain_index here
1666-
# entity_id - unique id for each biomolecule sequence
1667-
# sym_id - unique id for each chain of the same biomolecule sequence
1668-
1669-
# entity ids
1670-
1671-
unrepeated_entity_sequences = defaultdict(int)
1672-
for entity_sequence in chain_seqs:
1673-
if entity_sequence in unrepeated_entity_sequences:
1674-
continue
1675-
unrepeated_entity_sequences[entity_sequence] = len(unrepeated_entity_sequences)
1676-
1677-
entity_idx = 0
1678-
entity_id_counts = []
1679-
unrepeated_entity_ids = []
1680-
for entity_sequence, chain_chem_type in zip(chain_seqs, chain_chem_types):
1681-
entity_mol = molecules[entity_idx]
1682-
entity_len = (
1683-
entity_mol.GetNumAtoms() if chain_chem_type == "ligand" else len(entity_sequence)
1684-
)
1685-
entity_idx += 1 if chain_chem_type == "ligand" else len(entity_sequence)
1686-
1687-
entity_id_counts.append(entity_len)
1688-
unrepeated_entity_ids.append(unrepeated_entity_sequences[entity_sequence])
1689-
1690-
entity_ids = torch.repeat_interleave(tensor(unrepeated_entity_ids), tensor(entity_id_counts))
1691-
1692-
# sym ids
1693-
1694-
unrepeated_sym_ids = []
1695-
unrepeated_sym_sequences = defaultdict(int)
1696-
for entity_sequence in chain_seqs:
1697-
unrepeated_sym_ids.append(unrepeated_sym_sequences[entity_sequence])
1698-
if entity_sequence in unrepeated_sym_sequences:
1699-
unrepeated_sym_sequences[entity_sequence] += 1
1700-
unrepeated_sym_ids = tensor(unrepeated_sym_ids)
1701-
1702-
sym_ids = torch.repeat_interleave(tensor(unrepeated_sym_ids), tensor(entity_id_counts))
1703-
1704-
# concat for all of additional_molecule_feats
1705-
1706-
additional_molecule_feats = torch.stack(
1707-
(
1708-
# NOTE: `Biomolecule.residue_index` is 1-based originally
1709-
torch.from_numpy(biomol.residue_index) - 1,
1710-
torch.arange(num_tokens),
1711-
torch.from_numpy(biomol.chain_index),
1712-
entity_ids,
1713-
sym_ids,
1714-
),
1715-
dim=-1,
1716-
)
1717-
17181779
# handle missing atom indices
17191780
missing_atom_indices = None
17201781
missing_token_indices = None

0 commit comments

Comments
 (0)