|
6 | 6 | from collections import defaultdict |
7 | 7 | from dataclasses import asdict, dataclass, field |
8 | 8 | from functools import partial |
| 9 | +from itertools import groupby |
9 | 10 | from typing import Any, Callable, List, Set, Tuple, Type |
10 | 11 |
|
11 | 12 | import einx |
@@ -1405,6 +1406,57 @@ def get_token_index_from_composite_atom_id( |
1405 | 1406 | return np.where(chain_mask & res_mask & atom_mask)[0][atom_index] |
1406 | 1407 |
|
1407 | 1408 |
|
| 1409 | +@typecheck |
| 1410 | +def find_mismatched_symmetry(asym_ids: np.ndarray, entity_ids: np.ndarray, sym_ids: np.ndarray, chemid: np.ndarray) -> bool: |
| 1411 | + """ |
| 1412 | + Find mismatched symmetry in a biomolecule's asymmetry, entity, symmetry, and token chemical IDs. |
| 1413 | + |
| 1414 | + This function compares the chemical IDs of (related) regions with the same entity ID |
| 1415 | + but different symmetry IDs. If the chemical IDs of these regions' matching asymmetric |
| 1416 | + chain ID regions are not equal, then their symmetry is "mismatched". |
| 1417 | +
|
| 1418 | + :param asym_ids: An array of asymmetric unit (i.e., chain) IDs for each token in the biomolecule. |
| 1419 | + :param entity_ids: An array of entity IDs for each token in the biomolecule. |
| 1420 | + :param sym_ids: An array of symmetry IDs for each token in the biomolecule. |
| 1421 | + :param chemid: An array of chemical IDs for each token in the biomolecule. |
| 1422 | + :return: A boolean indicating whether the symmetry IDs are mismatched. |
| 1423 | + """ |
| 1424 | + assert len(asym_ids) == len(entity_ids) == len(sym_ids) == len(chemid), ( |
| 1425 | + f"The number of asymmetric unit IDs ({len(asym_ids)}), entity IDs ({len(entity_ids)}), symmetry IDs ({len(sym_ids)}), and chemical IDs ({len(chemid)}) do not match. " |
| 1426 | + "Please ensure that these input features are correctly paired." |
| 1427 | + ) |
| 1428 | + |
| 1429 | + # Create a combined array of tuples (asym_id, entity_id, sym_id, index) |
| 1430 | + combined = np.array(list(zip(asym_ids, entity_ids, sym_ids, range(len(entity_ids))))) |
| 1431 | + |
| 1432 | + # Group by entity_id |
| 1433 | + grouped_by_entity = defaultdict(list) |
| 1434 | + for entity, group in groupby(combined, key=lambda x: x[1]): |
| 1435 | + grouped_by_entity[entity].extend(list(group)) |
| 1436 | + |
| 1437 | + # Compare regions with the same entity_id but different sym_id |
| 1438 | + for entity, group in grouped_by_entity.items(): |
| 1439 | + # Group by sym_id within each entity_id group |
| 1440 | + grouped_by_sym = defaultdict(list) |
| 1441 | + for _, _, sym, idx in group: |
| 1442 | + grouped_by_sym[sym].append(idx) |
| 1443 | + |
| 1444 | + # Compare chemid sequences for the asym_id regions of different sym_id groups within the same entity_id group |
| 1445 | + sym_ids_keys = list(grouped_by_sym.keys()) |
| 1446 | + for i in range(len(sym_ids_keys)): |
| 1447 | + for j in range(i + 1, len(sym_ids_keys)): |
| 1448 | + indices1 = grouped_by_sym[sym_ids_keys[i]] |
| 1449 | + indices2 = grouped_by_sym[sym_ids_keys[j]] |
| 1450 | + indices1_asym_ids = np.unique(asym_ids[indices1]) |
| 1451 | + indices2_asym_ids = np.unique(asym_ids[indices2]) |
| 1452 | + chemid_seq1 = chemid[np.isin(asym_ids, indices1_asym_ids)] |
| 1453 | + chemid_seq2 = chemid[np.isin(asym_ids, indices2_asym_ids)] |
| 1454 | + if len(chemid_seq1) != len(chemid_seq2) or not np.array_equal(chemid_seq1, chemid_seq2): |
| 1455 | + return True |
| 1456 | + |
| 1457 | + return False |
| 1458 | + |
| 1459 | + |
1408 | 1460 | @typecheck |
1409 | 1461 | def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput: |
1410 | 1462 | """Convert a PDBInput to a MoleculeInput.""" |
@@ -1532,6 +1584,64 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput: |
1532 | 1584 | molecule_atom_indices = tensor(molecule_atom_indices) |
1533 | 1585 | distogram_atom_indices = tensor(distogram_atom_indices) |
1534 | 1586 |
|
| 1587 | + # constructing the additional_molecule_feats |
| 1588 | + # which is in turn used to derive relative positions |
| 1589 | + |
| 1590 | + # residue_index - an arange that restarts at 1 for each chain - reuse biomol.residue_index here |
| 1591 | + # token_index - just an arange |
| 1592 | + # asym_id - unique id for each chain of a biomolecule - reuse chain_index here |
| 1593 | + # entity_id - unique id for each biomolecule sequence |
| 1594 | + # sym_id - unique id for each chain of the same biomolecule sequence |
| 1595 | + |
| 1596 | + # entity ids |
| 1597 | + |
| 1598 | + unrepeated_entity_sequences = defaultdict(int) |
| 1599 | + for entity_sequence in chain_seqs: |
| 1600 | + if entity_sequence in unrepeated_entity_sequences: |
| 1601 | + continue |
| 1602 | + unrepeated_entity_sequences[entity_sequence] = len(unrepeated_entity_sequences) |
| 1603 | + |
| 1604 | + entity_idx = 0 |
| 1605 | + entity_id_counts = [] |
| 1606 | + unrepeated_entity_ids = [] |
| 1607 | + for entity_sequence, chain_chem_type in zip(chain_seqs, chain_chem_types): |
| 1608 | + entity_mol = molecules[entity_idx] |
| 1609 | + entity_len = ( |
| 1610 | + entity_mol.GetNumAtoms() if chain_chem_type == "ligand" else len(entity_sequence) |
| 1611 | + ) |
| 1612 | + entity_idx += 1 if chain_chem_type == "ligand" else len(entity_sequence) |
| 1613 | + |
| 1614 | + entity_id_counts.append(entity_len) |
| 1615 | + unrepeated_entity_ids.append(unrepeated_entity_sequences[entity_sequence]) |
| 1616 | + |
| 1617 | + entity_ids = torch.repeat_interleave(tensor(unrepeated_entity_ids), tensor(entity_id_counts)) |
| 1618 | + |
| 1619 | + # sym ids |
| 1620 | + |
| 1621 | + unrepeated_sym_ids = [] |
| 1622 | + unrepeated_sym_sequences = defaultdict(int) |
| 1623 | + for entity_sequence in chain_seqs: |
| 1624 | + unrepeated_sym_ids.append(unrepeated_sym_sequences[entity_sequence]) |
| 1625 | + if entity_sequence in unrepeated_sym_sequences: |
| 1626 | + unrepeated_sym_sequences[entity_sequence] += 1 |
| 1627 | + unrepeated_sym_ids = tensor(unrepeated_sym_ids) |
| 1628 | + |
| 1629 | + sym_ids = torch.repeat_interleave(tensor(unrepeated_sym_ids), tensor(entity_id_counts)) |
| 1630 | + |
| 1631 | + # concat for all of additional_molecule_feats |
| 1632 | + |
| 1633 | + additional_molecule_feats = torch.stack( |
| 1634 | + ( |
| 1635 | + # NOTE: `Biomolecule.residue_index` is 1-based originally |
| 1636 | + torch.from_numpy(biomol.residue_index) - 1, |
| 1637 | + torch.arange(num_tokens), |
| 1638 | + torch.from_numpy(biomol.chain_index), |
| 1639 | + entity_ids, |
| 1640 | + sym_ids, |
| 1641 | + ), |
| 1642 | + dim=-1, |
| 1643 | + ) |
| 1644 | + |
1535 | 1645 | # construct token bonds, which will be linearly connected for proteins |
1536 | 1646 | # and nucleic acids, but for ligands will have their atomic bond matrix |
1537 | 1647 | # (as ligands are atom resolution) |
@@ -1584,12 +1694,21 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput: |
1584 | 1694 | polymer_offset += chain_len |
1585 | 1695 | ligand_offset += chain_len |
1586 | 1696 |
|
| 1697 | + # ascertain whether homomeric (e.g., bonded ligand) symmetry is preserved, |
| 1698 | + # which determines whether or not we use the mmCIF bond inputs (AF3 Section 5.1) |
| 1699 | + lacking_homomeric_symmetry = find_mismatched_symmetry( |
| 1700 | + biomol.chain_index, |
| 1701 | + entity_ids.numpy(), |
| 1702 | + sym_ids.numpy(), |
| 1703 | + biomol.chemid, |
| 1704 | + ) |
| 1705 | + |
1587 | 1706 | # ensure mmCIF polymer-ligand (i.e., protein/RNA/DNA-ligand) and ligand-ligand bonds |
1588 | 1707 | # (and bonds less than 2.4 Å) are installed in `MoleculeInput` during training only |
1589 | 1708 | # per the AF3 supplement (Table 5, `token_bonds`) |
1590 | 1709 | bond_atom_indices = defaultdict(int) |
1591 | 1710 | for bond in biomol.bonds: |
1592 | | - if not i.training: |
| 1711 | + if not i.training or lacking_homomeric_symmetry: |
1593 | 1712 | continue |
1594 | 1713 |
|
1595 | 1714 | # determine bond type |
@@ -1657,64 +1776,6 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput: |
1657 | 1776 | bond_atom_indices[ptnr1_atom_id] += 1 |
1658 | 1777 | bond_atom_indices[ptnr2_atom_id] += 1 |
1659 | 1778 |
|
1660 | | - # constructing the additional_molecule_feats |
1661 | | - # which is in turn used to derive relative positions |
1662 | | - |
1663 | | - # residue_index - an arange that restarts at 1 for each chain - reuse biomol.residue_index here |
1664 | | - # token_index - just an arange |
1665 | | - # asym_id - unique id for each chain of a biomolecule - reuse chain_index here |
1666 | | - # entity_id - unique id for each biomolecule sequence |
1667 | | - # sym_id - unique id for each chain of the same biomolecule sequence |
1668 | | - |
1669 | | - # entity ids |
1670 | | - |
1671 | | - unrepeated_entity_sequences = defaultdict(int) |
1672 | | - for entity_sequence in chain_seqs: |
1673 | | - if entity_sequence in unrepeated_entity_sequences: |
1674 | | - continue |
1675 | | - unrepeated_entity_sequences[entity_sequence] = len(unrepeated_entity_sequences) |
1676 | | - |
1677 | | - entity_idx = 0 |
1678 | | - entity_id_counts = [] |
1679 | | - unrepeated_entity_ids = [] |
1680 | | - for entity_sequence, chain_chem_type in zip(chain_seqs, chain_chem_types): |
1681 | | - entity_mol = molecules[entity_idx] |
1682 | | - entity_len = ( |
1683 | | - entity_mol.GetNumAtoms() if chain_chem_type == "ligand" else len(entity_sequence) |
1684 | | - ) |
1685 | | - entity_idx += 1 if chain_chem_type == "ligand" else len(entity_sequence) |
1686 | | - |
1687 | | - entity_id_counts.append(entity_len) |
1688 | | - unrepeated_entity_ids.append(unrepeated_entity_sequences[entity_sequence]) |
1689 | | - |
1690 | | - entity_ids = torch.repeat_interleave(tensor(unrepeated_entity_ids), tensor(entity_id_counts)) |
1691 | | - |
1692 | | - # sym ids |
1693 | | - |
1694 | | - unrepeated_sym_ids = [] |
1695 | | - unrepeated_sym_sequences = defaultdict(int) |
1696 | | - for entity_sequence in chain_seqs: |
1697 | | - unrepeated_sym_ids.append(unrepeated_sym_sequences[entity_sequence]) |
1698 | | - if entity_sequence in unrepeated_sym_sequences: |
1699 | | - unrepeated_sym_sequences[entity_sequence] += 1 |
1700 | | - unrepeated_sym_ids = tensor(unrepeated_sym_ids) |
1701 | | - |
1702 | | - sym_ids = torch.repeat_interleave(tensor(unrepeated_sym_ids), tensor(entity_id_counts)) |
1703 | | - |
1704 | | - # concat for all of additional_molecule_feats |
1705 | | - |
1706 | | - additional_molecule_feats = torch.stack( |
1707 | | - ( |
1708 | | - # NOTE: `Biomolecule.residue_index` is 1-based originally |
1709 | | - torch.from_numpy(biomol.residue_index) - 1, |
1710 | | - torch.arange(num_tokens), |
1711 | | - torch.from_numpy(biomol.chain_index), |
1712 | | - entity_ids, |
1713 | | - sym_ids, |
1714 | | - ), |
1715 | | - dim=-1, |
1716 | | - ) |
1717 | | - |
1718 | 1779 | # handle missing atom indices |
1719 | 1780 | missing_atom_indices = None |
1720 | 1781 | missing_token_indices = None |
|
0 commit comments