Skip to content

Commit dea7ba1

Browse files
authored
Generalize CovalentBond to Bond to prepare for bond featurization during model training (#72)
* Update biomolecule.py * Update mmcif_parsing.py * Update filter_pdb_mmcifs.py
1 parent a05aecc commit dea7ba1

File tree

3 files changed

+131
-42
lines changed

3 files changed

+131
-42
lines changed

alphafold3_pytorch/common/biomolecule.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"_pdbx_struct_assembly.",
3737
"_pdbx_struct_assembly_gen.",
3838
"_struct_asym.",
39+
"_struct_conn.",
3940
]
4041
MMCIF_PREFIXES_TO_DROP_POST_AF3 = MMCIF_PREFIXES_TO_DROP_POST_PARSING + [
4142
"_citation.",
@@ -90,12 +91,18 @@ class Biomolecule:
9091
# a protein (0), RNA (1), DNA (2), or ligand (3) residue.
9192
chemtype: np.ndarray # [num_res]
9293

94+
# Bonds between atoms in the biomolecule.
95+
bonds: Optional[List[mmcif_parsing.Bond]] # [num_bonds]
96+
9397
# Atom name-chain ID-residue ID tuples for each (e.g. ligand) "pseudoresidue" of each residue in each chain.
9498
# This is used to group "pseudoresidues" (e.g., ligand atoms) by parent residue.
9599
unique_res_atom_names: Optional[
96100
List[Tuple[List[List[str]], str, int]]
97101
] # [num_res, num_pseudoresidues, num_atoms]
98102

103+
# Mapping from (original) author chain ID-residue name-residue ID (CRI) tuples to (new) author CRI tuples.
104+
author_cri_to_new_cri: Dict[Tuple[str, str, int], Tuple[str, str, int]] # [num_res]
105+
99106
# Chemical component details of each residue as a unique `ChemComp` object.
100107
# This is used to determine the biomolecule's unique chemical IDs, names, types, etc.
101108
# N.b., this is primarily used to record chemical component metadata
@@ -125,7 +132,9 @@ def __add__(self, other: "Biomolecule") -> "Biomolecule":
125132
b_factors=np.concatenate([self.b_factors, other.b_factors], axis=0),
126133
chemid=np.concatenate([self.chemid, other.chemid], axis=0),
127134
chemtype=np.concatenate([self.chemtype, other.chemtype], axis=0),
135+
bonds=list(dict.fromkeys(self.bonds + other.bonds)),
128136
unique_res_atom_names=self.unique_res_atom_names + other.unique_res_atom_names,
137+
author_cri_to_new_cri={**self.author_cri_to_new_cri, **other.author_cri_to_new_cri},
129138
chem_comp_table=self.chem_comp_table.union(other.chem_comp_table),
130139
entity_to_chain=deep_merge_dicts(
131140
self.entity_to_chain, other.entity_to_chain, value_op="union"
@@ -168,11 +177,22 @@ def subset_chains(self, subset_chain_ids: List[str]) -> "Biomolecule":
168177
b_factors=self.b_factors[chain_mask],
169178
chemid=self.chemid[chain_mask],
170179
chemtype=self.chemtype[chain_mask],
180+
bonds=[
181+
bond
182+
for bond in self.bonds
183+
if bond.ptnr1_auth_asym_id in subset_chain_ids
184+
and bond.ptnr2_auth_asym_id in subset_chain_ids
185+
],
171186
unique_res_atom_names=[
172187
unique_res_atom_names
173188
for unique_res_atom_names in self.unique_res_atom_names
174189
if unique_res_atom_names[1] in subset_chain_ids
175190
],
191+
author_cri_to_new_cri={
192+
author_cri: new_cri
193+
for author_cri, new_cri in self.author_cri_to_new_cri.items()
194+
if new_cri[0] in subset_chain_index_mapping
195+
},
176196
chem_comp_table=self.chem_comp_table,
177197
entity_to_chain=entity_to_chain,
178198
mmcif_to_author_chain=mmcif_to_author_chain,
@@ -191,11 +211,13 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule":
191211
b_factors=np.tile(self.b_factors, (coord.shape[0], 1, 1)).reshape(-1, 47),
192212
chemid=np.tile(self.chemid, (coord.shape[0], 1)).reshape(-1),
193213
chemtype=np.tile(self.chemtype, (coord.shape[0], 1)).reshape(-1),
214+
bonds=self.bonds,
194215
unique_res_atom_names=[
195216
unique_res_atom_names
196217
for _ in range(coord.shape[0])
197218
for unique_res_atom_names in self.unique_res_atom_names
198219
],
220+
author_cri_to_new_cri=self.author_cri_to_new_cri,
199221
chem_comp_table=self.chem_comp_table,
200222
entity_to_chain=self.entity_to_chain,
201223
mmcif_to_author_chain=self.mmcif_to_author_chain,
@@ -320,6 +342,7 @@ def _from_mmcif_object(
320342
residue_index = []
321343
chain_ids = []
322344
b_factors = []
345+
author_cri_to_new_cri = {}
323346

324347
for chain in model:
325348
if exists(chain_id) and chain.id != chain_id:
@@ -349,7 +372,11 @@ def _from_mmcif_object(
349372
for atom in res:
350373
if is_polymer_residue and atom.name not in residue_constants.atom_types_set:
351374
continue
352-
elif is_peptide_residue and atom.name.upper() == "SE" and res.get_resname() == "MSE":
375+
elif (
376+
is_peptide_residue
377+
and atom.name.upper() == "SE"
378+
and res.get_resname() == "MSE"
379+
):
353380
# Put the coords of the selenium atom in the sulphur column.
354381
pos[residue_constants.atom_order["SD"]] = atom.coord
355382
mask[residue_constants.atom_order["SD"]] = 1.0
@@ -370,7 +397,9 @@ def _from_mmcif_object(
370397
if (
371398
res.get_resname() == "ARG"
372399
and all(mask[atom_index] for atom_index in (cd, nh1, nh2))
373-
and (np.linalg.norm(pos[nh1] - pos[cd]) > np.linalg.norm(pos[nh2] - pos[cd]))
400+
and (
401+
np.linalg.norm(pos[nh1] - pos[cd]) > np.linalg.norm(pos[nh2] - pos[cd])
402+
)
374403
):
375404
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
376405
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
@@ -387,6 +416,11 @@ def _from_mmcif_object(
387416
residue_index.append(res_index + 1)
388417
chain_ids.append(chain.id)
389418
b_factors.append(res_b_factors)
419+
author_cri_to_new_cri[(chain.id, res.resname, res.id[1])] = (
420+
chain.id,
421+
res.resname,
422+
res_index + 1,
423+
)
390424
if res.resname == residue_constants.unk_restype:
391425
# If the polymer residue is unknown, then it is of the corresponding unknown polymer residue type.
392426
residue_chem_comp_details.add(
@@ -426,6 +460,12 @@ def _from_mmcif_object(
426460
chain_ids.append(chain.id)
427461
b_factors.append(res_b_factors)
428462

463+
author_cri_to_new_cri[(chain.id, res.resname, res.id[1])] = (
464+
chain.id,
465+
res.resname,
466+
res_index + 1,
467+
)
468+
429469
if res.resname == residue_constants.unk_restype:
430470
# If the ligand residue is unknown, then it is of the unknown ligand residue type.
431471
residue_chem_comp_details.add(
@@ -473,7 +513,9 @@ def _from_mmcif_object(
473513
b_factors=np.array(b_factors),
474514
chemid=np.array(chemid),
475515
chemtype=np.array(chemtype),
516+
bonds=mmcif_object.bonds,
476517
unique_res_atom_names=unique_res_atom_names,
518+
author_cri_to_new_cri=author_cri_to_new_cri,
477519
chem_comp_table=residue_chem_comp_details,
478520
entity_to_chain=entity_to_chain,
479521
mmcif_to_author_chain=mmcif_to_author_chain,
@@ -605,6 +647,8 @@ def to_mmcif(
605647
b_factors = biomol.b_factors
606648
chemid = biomol.chemid
607649
chemtype = biomol.chemtype
650+
bonds = biomol.bonds
651+
author_cri_to_new_cri = biomol.author_cri_to_new_cri
608652
entity_id_to_chain_ids = biomol.entity_to_chain
609653
mmcif_to_author_chain_ids = biomol.mmcif_to_author_chain
610654
orig_mmcif_metadata = biomol.mmcif_metadata
@@ -751,6 +795,45 @@ def to_mmcif(
751795
str(pdbx_struct_assembly_oligomeric_count[assembly_id])
752796
)
753797

798+
# Populate the _struct_conn table.
799+
for bond in bonds:
800+
# Skip bonds between residues that have previously been filtered out.
801+
ptnr1_key = (
802+
bond.ptnr1_auth_asym_id,
803+
bond.ptnr1_auth_comp_id,
804+
int(bond.ptnr1_auth_seq_id),
805+
)
806+
ptnr2_key = (
807+
bond.ptnr2_auth_asym_id,
808+
bond.ptnr2_auth_comp_id,
809+
int(bond.ptnr2_auth_seq_id),
810+
)
811+
if ptnr1_key not in author_cri_to_new_cri or ptnr2_key not in author_cri_to_new_cri:
812+
continue
813+
# Partner 1
814+
ptnr1_mapping = author_cri_to_new_cri[ptnr1_key]
815+
mmcif_dict["_struct_conn.ptnr1_auth_seq_id"].append(
816+
str(ptnr1_mapping[2])
817+
) # Reindex ptnr1 residue ID.
818+
mmcif_dict["_struct_conn.ptnr1_auth_comp_id"].append(bond.ptnr1_auth_comp_id)
819+
mmcif_dict["_struct_conn.ptnr1_auth_asym_id"].append(bond.ptnr1_auth_asym_id)
820+
mmcif_dict["_struct_conn.ptnr1_label_atom_id"].append(bond.ptnr1_label_atom_id)
821+
mmcif_dict["_struct_conn.pdbx_ptnr1_label_alt_id"].append(bond.pdbx_ptnr1_label_alt_id)
822+
# Partner 2
823+
ptnr2_mapping = author_cri_to_new_cri[ptnr2_key]
824+
mmcif_dict["_struct_conn.ptnr2_auth_seq_id"].append(
825+
str(ptnr2_mapping[2])
826+
) # Reindex ptnr2 residue ID.
827+
mmcif_dict["_struct_conn.ptnr2_auth_comp_id"].append(bond.ptnr2_auth_comp_id)
828+
mmcif_dict["_struct_conn.ptnr2_auth_asym_id"].append(bond.ptnr2_auth_asym_id)
829+
mmcif_dict["_struct_conn.ptnr2_label_atom_id"].append(bond.ptnr2_label_atom_id)
830+
mmcif_dict["_struct_conn.pdbx_ptnr2_label_alt_id"].append(bond.pdbx_ptnr2_label_alt_id)
831+
# Connection metadata
832+
mmcif_dict["_struct_conn.pdbx_leaving_atom_flag"].append(bond.pdbx_leaving_atom_flag)
833+
mmcif_dict["_struct_conn.pdbx_dist_value"].append(bond.pdbx_dist_value)
834+
mmcif_dict["_struct_conn.pdbx_role"].append(bond.pdbx_role)
835+
mmcif_dict["_struct_conn.conn_type_id"].append(bond.conn_type_id)
836+
754837
# Populate the _chem_comp table.
755838
for chem_comp in biomol.chem_comp_table:
756839
mmcif_dict["_chem_comp.id"].append(chem_comp.id)

alphafold3_pytorch/data/mmcif_parsing.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ class AtomSite:
6565

6666

6767
@dataclasses.dataclass(frozen=True)
68-
class CovalentBond:
69-
"""Represents a covalent bond between two atoms."""
68+
class Bond:
69+
"""Represents a structural connection between two atoms."""
7070

7171
ptnr1_auth_seq_id: str
7272
ptnr1_auth_comp_id: str
@@ -80,7 +80,9 @@ class CovalentBond:
8080
ptnr2_label_atom_id: str
8181
pdbx_ptnr2_label_alt_id: str
8282

83-
leaving_atom_flag: str
83+
pdbx_leaving_atom_flag: str
84+
pdbx_dist_value: str
85+
pdbx_role: str
8486
conn_type_id: str
8587

8688

@@ -127,7 +129,7 @@ class MmcifObject:
127129
{1: ['A', 'B']}
128130
mmcif_to_author_chain: Dict mapping internal mmCIF chain ids to author chain ids. E.g.
129131
{'A': 'B', 'B', 'B'}
130-
covalent_bonds: List of CovalentBond.
132+
bonds: List of Bond objects.
131133
raw_string: The raw string used to construct the MmcifObject.
132134
atoms_to_remove: Optional set of atoms to remove.
133135
residues_to_remove: Optional set of residues to remove.
@@ -143,7 +145,7 @@ class MmcifObject:
143145
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
144146
entity_to_chain: Mapping[int, Sequence[str]]
145147
mmcif_to_author_chain: Mapping[str, str]
146-
covalent_bonds: Sequence[CovalentBond]
148+
bonds: Sequence[Bond]
147149
raw_string: Any
148150
atoms_to_remove: Set[AtomFullId]
149151
residues_to_remove: Set[ResidueFullId]
@@ -541,8 +543,8 @@ def parse(
541543
for entity_id, chains in mmcif_entity_to_author_chain_mappings.items()
542544
}
543545

544-
# Identify all covalent bonds.
545-
covalent_bonds = _get_covalent_bond_list(parsed_info)
546+
# Identify all bonds.
547+
bonds = _get_bond_list(parsed_info)
546548

547549
mmcif_object = MmcifObject(
548550
file_id=file_id,
@@ -554,7 +556,7 @@ def parse(
554556
seqres_to_structure=seq_to_structure_mappings,
555557
entity_to_chain=entity_to_chain,
556558
mmcif_to_author_chain=mmcif_to_author_chain_id,
557-
covalent_bonds=covalent_bonds,
559+
bonds=bonds,
558560
raw_string=parsed_info,
559561
atoms_to_remove=set(),
560562
residues_to_remove=set(),
@@ -631,12 +633,12 @@ def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
631633
]
632634

633635

634-
def _get_covalent_bond_list(parsed_info: MmCIFDict) -> Sequence[CovalentBond]:
635-
"""Returns list of covalent bonds present in the structure."""
636+
def _get_bond_list(parsed_info: MmCIFDict) -> Sequence[Bond]:
637+
"""Returns list of bonds present in the structure."""
636638
return [
637-
# Collect unique (partner) atom metadata required for each covalent bond
639+
# Collect unique (partner) atom metadata required for each bond
638640
# per https://mmcif.wwpdb.org/docs/sw-examples/python/html/connections3.html.
639-
CovalentBond(*conn)
641+
Bond(*conn)
640642
for conn in zip( # pylint:disable=g-complex-comprehension
641643
# Partner 1
642644
parsed_info.get("_struct_conn.ptnr1_auth_seq_id", []),
@@ -652,9 +654,11 @@ def _get_covalent_bond_list(parsed_info: MmCIFDict) -> Sequence[CovalentBond]:
652654
parsed_info.get("_struct_conn.pdbx_ptnr2_label_alt_id", []),
653655
# Connection metadata
654656
parsed_info.get("_struct_conn.pdbx_leaving_atom_flag", []),
657+
parsed_info.get("_struct_conn.pdbx_dist_value", []),
658+
parsed_info.get("_struct_conn.pdbx_role", []),
655659
parsed_info.get("_struct_conn.conn_type_id", []),
656660
)
657-
if len(conn[-1]) and conn[-1].lower() == "covale"
661+
if len(conn[-1]) > 0
658662
]
659663

660664

0 commit comments

Comments
 (0)