@@ -142,13 +142,18 @@ def __add__(self, other: "Biomolecule") -> "Biomolecule":
142142 chemtype = np .concatenate ([self .chemtype , other .chemtype ], axis = 0 ),
143143 bonds = list (dict .fromkeys (self .bonds + other .bonds )),
144144 unique_res_atom_names = self .unique_res_atom_names + other .unique_res_atom_names ,
145- author_cri_to_new_cri = {** self .author_cri_to_new_cri , ** other .author_cri_to_new_cri },
145+ author_cri_to_new_cri = {
146+ ** self .author_cri_to_new_cri ,
147+ ** other .author_cri_to_new_cri ,
148+ },
146149 chem_comp_table = self .chem_comp_table .union (other .chem_comp_table ),
147150 entity_to_chain = deep_merge_dicts (
148151 self .entity_to_chain , other .entity_to_chain , value_op = "union"
149152 ),
150153 mmcif_to_author_chain = deep_merge_dicts (
151- self .mmcif_to_author_chain , other .mmcif_to_author_chain , value_op = "union"
154+ self .mmcif_to_author_chain ,
155+ other .mmcif_to_author_chain ,
156+ value_op = "union" ,
152157 ),
153158 mmcif_metadata = {** self .mmcif_metadata , ** other .mmcif_metadata },
154159 )
@@ -237,10 +242,8 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule":
237242 def crop_chains_with_masks (
238243 self , chain_ids_and_lengths : List [Tuple [str , int ]], crop_masks : List [np .ndarray ]
239244 ) -> "Biomolecule" :
240- """
241- Crop the chains and metadata within a Biomolecule
242- to only include the specified chain residues.
243- """
245+ """Crop the chains and metadata within a Biomolecule to only include the specified chain
246+ residues."""
244247 assert len (chain_ids_and_lengths ) == len (
245248 crop_masks
246249 ), "The number of chains and crop masks must be equal."
@@ -320,7 +323,10 @@ def crop_chains_with_masks(
320323 unique_res_atom_names
321324 for unique_res_atom_names in self .unique_res_atom_names
322325 if unique_res_atom_names [1 ] not in chains_to_remove
323- and (subset_chain_id_mapping [unique_res_atom_names [1 ]], unique_res_atom_names [2 ])
326+ and (
327+ subset_chain_id_mapping [unique_res_atom_names [1 ]],
328+ unique_res_atom_names [2 ],
329+ )
324330 in subset_chain_residue_mapping
325331 ],
326332 author_cri_to_new_cri = {
@@ -334,15 +340,13 @@ def crop_chains_with_masks(
334340 mmcif_metadata = self .mmcif_metadata ,
335341 )
336342
337- def contiguous_crop (self , n_res : int = 384 ) -> "Biomolecule" :
338- """
339- Crop a Biomolecule to only include contiguous
340- polymer residues and/or ligand atoms for each chain.
341- """
343+ def contiguous_crop (self , n_res : int = 384 ) -> Tuple ["Biomolecule" , List [Tuple [str , int ]], List [np .ndarray ]]:
344+ """Crop a Biomolecule to only include contiguous polymer residues and/or ligand atoms for
345+ each chain."""
342346 chain_ids_and_lengths = list (collections .Counter (self .chain_id ).items ())
343347 random .shuffle (chain_ids_and_lengths )
344348 crop_masks = create_contiguous_crop_masks (chain_ids_and_lengths , n_res )
345- return self .crop_chains_with_masks (chain_ids_and_lengths , crop_masks )
349+ return self .crop_chains_with_masks (chain_ids_and_lengths , crop_masks ), chain_ids_and_lengths , crop_masks
346350
347351 def spatial_crop (
348352 self ,
@@ -351,11 +355,9 @@ def spatial_crop(
351355 chain_1 : Optional [str ] = None ,
352356 chain_2 : Optional [str ] = None ,
353357 interface_distance_threshold : float = 15.0 ,
354- ) -> "Biomolecule" :
355- """
356- Crop a Biomolecule to only include polymer residues and ligand atoms
357- near a (random) reference atom within a sampled chain/interface.
358- """
358+ ) -> Tuple ["Biomolecule" , List [Tuple [str , int ]], List [np .ndarray ]]:
359+ """Crop a Biomolecule to only include polymer residues and ligand atoms near a (random)
360+ reference atom within a sampled chain/interface."""
359361
360362 # curate a list of candidate token center atoms from which to sample a reference atom
361363
@@ -430,7 +432,7 @@ def spatial_crop(
430432
431433 # sample a reference atom for spatial cropping
432434
433- reference_atom_index = random .choice (token_center_atom_indices ).item ()
435+ reference_atom_index = random .choice (token_center_atom_indices ).item () # nosec
434436
435437 # perform spatial cropping according to reference atom proximity
436438
@@ -441,7 +443,7 @@ def spatial_crop(
441443 reference_atom_index ,
442444 n_res ,
443445 )
444- return self .crop_chains_with_masks (chain_ids_and_lengths , crop_masks )
446+ return self .crop_chains_with_masks (chain_ids_and_lengths , crop_masks ), chain_ids_and_lengths , crop_masks
445447
446448 def crop (
447449 self ,
@@ -451,7 +453,7 @@ def crop(
451453 n_res : int = 384 ,
452454 chain_1 : str | None = None ,
453455 chain_2 : str | None = None ,
454- ) -> "Biomolecule" :
456+ ) -> Tuple [ "Biomolecule" , List [ Tuple [ str , int ]], List [ np . ndarray ]] :
455457 """Crop a Biomolecule using a randomly-sampled cropping function."""
456458 n_res = min (n_res , len (self .atom_mask ))
457459 if exists (chain_1 ) and exists (chain_2 ):
@@ -481,16 +483,16 @@ def crop(
481483 chain_2 = chain_2 ,
482484 ),
483485 ]
484- crop_fn = random .choices (crop_fns , crop_fn_weights )[0 ]
486+ crop_fn = random .choices (crop_fns , crop_fn_weights )[0 ] # nosec
485487 return crop_fn ()
486488
487489
488490@typecheck
489491def create_contiguous_crop_masks (
490492 chain_ids_and_lengths : List [Tuple [str , int ]], n_res : int
491493) -> List [np .ndarray ]:
492- """
493- Create contiguous crop masks for each given chain.
494+ """Create contiguous crop masks for each given chain.
495+
494496 Implements Algorithm 1 from the AlphaFold-Multimer paper.
495497 """
496498 m_ks = []
@@ -502,9 +504,9 @@ def create_contiguous_crop_masks(
502504 crop_size_max = min (n_res - n_added , n_k )
503505 # NOTE: `max(0, n_remaining)` was analytically added to prevent invalid crop sizes.
504506 crop_size_min = min (n_k , max (0 , n_res - (n_added + max (0 , n_remaining ))))
505- crop_size = random .randrange (crop_size_min , crop_size_max + 1 )
507+ crop_size = random .randrange (crop_size_min , crop_size_max + 1 ) # nosec
506508 n_added += crop_size
507- crop_start = random .randrange (0 , n_k - crop_size + 1 )
509+ crop_start = random .randrange (0 , n_k - crop_size + 1 ) # nosec
508510 m_k = np .zeros (n_k , dtype = bool )
509511 keep = np .arange (crop_start , crop_start + crop_size )
510512 m_k [keep ] = True
@@ -519,8 +521,8 @@ def create_spatial_crop_masks(
519521 reference_token_center_atom_index : int ,
520522 n_res : int ,
521523) -> List [np .ndarray ]:
522- """
523- Create spatial crop masks for each given chain.
524+ """Create spatial crop masks for each given chain.
525+
524526 Implements Algorithm 2 from the AlphaFold-Multimer paper.
525527 """
526528 # calculate distances with small uniquifying values to break ties
@@ -715,9 +717,7 @@ def _from_mmcif_object(
715717 residize_polymer = is_polymer_residue and not (
716718 is_modified_polymer_residue and atomize_modified_polymer_residues
717719 )
718- residize_non_polymer = (
719- not is_polymer_residue and not atomize_ligand_residues
720- )
720+ residize_non_polymer = not is_polymer_residue and not atomize_ligand_residues
721721 if residize_polymer or residize_non_polymer :
722722 pos = np .zeros ((residue_constants .atom_type_num , 3 ))
723723 mask = np .zeros ((residue_constants .atom_type_num ,))
@@ -898,6 +898,14 @@ def from_mmcif_string(
898898 :param file_id: The file ID (usually the PDB ID) to be used in the mmCIF.
899899 :param chain_ids: If chain_ids are specified (e.g. A), then only these chains are parsed.
900900 Otherwise all chains are parsed.
901+ :param atomize_ligand_residues: If True, then the atoms of ligand
902+ residues are treated as "pseudoresidues". This is useful for
903+ representing ligand residues as a collection of atoms rather
904+ than as a single residue.
905+ :param atomize_modified_polymer_residues: If True, then the atoms of modified
906+ polymer residues are treated as "pseudoresidues". This is useful for
907+ representing modified polymer residues as a collection of (e.g., ligand)
908+ atoms rather than as a single residue.
901909
902910 :return: A new `Biomolecule` parsed from the mmCIF contents.
903911
@@ -946,7 +954,6 @@ def remove_metadata_fields_by_prefixes(
946954
947955 :param metadata_dict: The metadata default dictionary from which to remove metadata fields.
948956 :param field_prefixes: A list of prefixes to remove from the metadata default dictionary.
949-
950957 :return: A metadata dictionary with the specified metadata fields removed.
951958 """
952959 return {
0 commit comments