Skip to content

Commit dc2aed6

Browse files
authored
Crop MSA features during training (#185)
* Update biomolecule.py * Update data_pipeline.py * Update data_utils.py * Update inputs.py
1 parent 6388cfc commit dc2aed6

File tree

4 files changed

+161
-114
lines changed

4 files changed

+161
-114
lines changed

alphafold3_pytorch/common/biomolecule.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,18 @@ def __add__(self, other: "Biomolecule") -> "Biomolecule":
142142
chemtype=np.concatenate([self.chemtype, other.chemtype], axis=0),
143143
bonds=list(dict.fromkeys(self.bonds + other.bonds)),
144144
unique_res_atom_names=self.unique_res_atom_names + other.unique_res_atom_names,
145-
author_cri_to_new_cri={**self.author_cri_to_new_cri, **other.author_cri_to_new_cri},
145+
author_cri_to_new_cri={
146+
**self.author_cri_to_new_cri,
147+
**other.author_cri_to_new_cri,
148+
},
146149
chem_comp_table=self.chem_comp_table.union(other.chem_comp_table),
147150
entity_to_chain=deep_merge_dicts(
148151
self.entity_to_chain, other.entity_to_chain, value_op="union"
149152
),
150153
mmcif_to_author_chain=deep_merge_dicts(
151-
self.mmcif_to_author_chain, other.mmcif_to_author_chain, value_op="union"
154+
self.mmcif_to_author_chain,
155+
other.mmcif_to_author_chain,
156+
value_op="union",
152157
),
153158
mmcif_metadata={**self.mmcif_metadata, **other.mmcif_metadata},
154159
)
@@ -237,10 +242,8 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule":
237242
def crop_chains_with_masks(
238243
self, chain_ids_and_lengths: List[Tuple[str, int]], crop_masks: List[np.ndarray]
239244
) -> "Biomolecule":
240-
"""
241-
Crop the chains and metadata within a Biomolecule
242-
to only include the specified chain residues.
243-
"""
245+
"""Crop the chains and metadata within a Biomolecule to only include the specified chain
246+
residues."""
244247
assert len(chain_ids_and_lengths) == len(
245248
crop_masks
246249
), "The number of chains and crop masks must be equal."
@@ -320,7 +323,10 @@ def crop_chains_with_masks(
320323
unique_res_atom_names
321324
for unique_res_atom_names in self.unique_res_atom_names
322325
if unique_res_atom_names[1] not in chains_to_remove
323-
and (subset_chain_id_mapping[unique_res_atom_names[1]], unique_res_atom_names[2])
326+
and (
327+
subset_chain_id_mapping[unique_res_atom_names[1]],
328+
unique_res_atom_names[2],
329+
)
324330
in subset_chain_residue_mapping
325331
],
326332
author_cri_to_new_cri={
@@ -334,15 +340,13 @@ def crop_chains_with_masks(
334340
mmcif_metadata=self.mmcif_metadata,
335341
)
336342

337-
def contiguous_crop(self, n_res: int = 384) -> "Biomolecule":
338-
"""
339-
Crop a Biomolecule to only include contiguous
340-
polymer residues and/or ligand atoms for each chain.
341-
"""
343+
def contiguous_crop(self, n_res: int = 384) -> Tuple["Biomolecule", List[Tuple[str, int]], List[np.ndarray]]:
344+
"""Crop a Biomolecule to only include contiguous polymer residues and/or ligand atoms for
345+
each chain."""
342346
chain_ids_and_lengths = list(collections.Counter(self.chain_id).items())
343347
random.shuffle(chain_ids_and_lengths)
344348
crop_masks = create_contiguous_crop_masks(chain_ids_and_lengths, n_res)
345-
return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks)
349+
return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks), chain_ids_and_lengths, crop_masks
346350

347351
def spatial_crop(
348352
self,
@@ -351,11 +355,9 @@ def spatial_crop(
351355
chain_1: Optional[str] = None,
352356
chain_2: Optional[str] = None,
353357
interface_distance_threshold: float = 15.0,
354-
) -> "Biomolecule":
355-
"""
356-
Crop a Biomolecule to only include polymer residues and ligand atoms
357-
near a (random) reference atom within a sampled chain/interface.
358-
"""
358+
) -> Tuple["Biomolecule", List[Tuple[str, int]], List[np.ndarray]]:
359+
"""Crop a Biomolecule to only include polymer residues and ligand atoms near a (random)
360+
reference atom within a sampled chain/interface."""
359361

360362
# curate a list of candidate token center atoms from which to sample a reference atom
361363

@@ -430,7 +432,7 @@ def spatial_crop(
430432

431433
# sample a reference atom for spatial cropping
432434

433-
reference_atom_index = random.choice(token_center_atom_indices).item()
435+
reference_atom_index = random.choice(token_center_atom_indices).item() # nosec
434436

435437
# perform spatial cropping according to reference atom proximity
436438

@@ -441,7 +443,7 @@ def spatial_crop(
441443
reference_atom_index,
442444
n_res,
443445
)
444-
return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks)
446+
return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks), chain_ids_and_lengths, crop_masks
445447

446448
def crop(
447449
self,
@@ -451,7 +453,7 @@ def crop(
451453
n_res: int = 384,
452454
chain_1: str | None = None,
453455
chain_2: str | None = None,
454-
) -> "Biomolecule":
456+
) -> Tuple["Biomolecule", List[Tuple[str, int]], List[np.ndarray]]:
455457
"""Crop a Biomolecule using a randomly-sampled cropping function."""
456458
n_res = min(n_res, len(self.atom_mask))
457459
if exists(chain_1) and exists(chain_2):
@@ -481,16 +483,16 @@ def crop(
481483
chain_2=chain_2,
482484
),
483485
]
484-
crop_fn = random.choices(crop_fns, crop_fn_weights)[0]
486+
crop_fn = random.choices(crop_fns, crop_fn_weights)[0] # nosec
485487
return crop_fn()
486488

487489

488490
@typecheck
489491
def create_contiguous_crop_masks(
490492
chain_ids_and_lengths: List[Tuple[str, int]], n_res: int
491493
) -> List[np.ndarray]:
492-
"""
493-
Create contiguous crop masks for each given chain.
494+
"""Create contiguous crop masks for each given chain.
495+
494496
Implements Algorithm 1 from the AlphaFold-Multimer paper.
495497
"""
496498
m_ks = []
@@ -502,9 +504,9 @@ def create_contiguous_crop_masks(
502504
crop_size_max = min(n_res - n_added, n_k)
503505
# NOTE: `max(0, n_remaining)` was analytically added to prevent invalid crop sizes.
504506
crop_size_min = min(n_k, max(0, n_res - (n_added + max(0, n_remaining))))
505-
crop_size = random.randrange(crop_size_min, crop_size_max + 1)
507+
crop_size = random.randrange(crop_size_min, crop_size_max + 1) # nosec
506508
n_added += crop_size
507-
crop_start = random.randrange(0, n_k - crop_size + 1)
509+
crop_start = random.randrange(0, n_k - crop_size + 1) # nosec
508510
m_k = np.zeros(n_k, dtype=bool)
509511
keep = np.arange(crop_start, crop_start + crop_size)
510512
m_k[keep] = True
@@ -519,8 +521,8 @@ def create_spatial_crop_masks(
519521
reference_token_center_atom_index: int,
520522
n_res: int,
521523
) -> List[np.ndarray]:
522-
"""
523-
Create spatial crop masks for each given chain.
524+
"""Create spatial crop masks for each given chain.
525+
524526
Implements Algorithm 2 from the AlphaFold-Multimer paper.
525527
"""
526528
# calculate distances with small uniquifying values to break ties
@@ -715,9 +717,7 @@ def _from_mmcif_object(
715717
residize_polymer = is_polymer_residue and not (
716718
is_modified_polymer_residue and atomize_modified_polymer_residues
717719
)
718-
residize_non_polymer = (
719-
not is_polymer_residue and not atomize_ligand_residues
720-
)
720+
residize_non_polymer = not is_polymer_residue and not atomize_ligand_residues
721721
if residize_polymer or residize_non_polymer:
722722
pos = np.zeros((residue_constants.atom_type_num, 3))
723723
mask = np.zeros((residue_constants.atom_type_num,))
@@ -898,6 +898,14 @@ def from_mmcif_string(
898898
:param file_id: The file ID (usually the PDB ID) to be used in the mmCIF.
899899
:param chain_ids: If chain_ids are specified (e.g. A), then only these chains are parsed.
900900
Otherwise all chains are parsed.
901+
:param atomize_ligand_residues: If True, then the atoms of ligand
902+
residues are treated as "pseudoresidues". This is useful for
903+
representing ligand residues as a collection of atoms rather
904+
than as a single residue.
905+
:param atomize_modified_polymer_residues: If True, then the atoms of modified
906+
polymer residues are treated as "pseudoresidues". This is useful for
907+
representing modified polymer residues as a collection of (e.g., ligand)
908+
atoms rather than as a single residue.
901909
902910
:return: A new `Biomolecule` parsed from the mmCIF contents.
903911
@@ -946,7 +954,6 @@ def remove_metadata_fields_by_prefixes(
946954
947955
:param metadata_dict: The metadata default dictionary from which to remove metadata fields.
948956
:param field_prefixes: A list of prefixes to remove from the metadata default dictionary.
949-
950957
:return: A metadata dictionary with the specified metadata fields removed.
951958
"""
952959
return {

alphafold3_pytorch/data/data_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def make_mmcif_features(
327327
file_id=file_id,
328328
)
329329
mmcif_feats, assembly = make_mmcif_features(mmcif_object)
330-
cropped_assembly = assembly.crop(
330+
cropped_assembly, _, _ = assembly.crop(
331331
contiguous_weight=0.2,
332332
spatial_weight=0.4,
333333
spatial_interface_weight=0.4,

0 commit comments

Comments
 (0)