|
4 | 4 | import dataclasses |
5 | 5 | import functools |
6 | 6 | import io |
| 7 | +import random |
7 | 8 | from types import ModuleType |
8 | 9 | from typing import Any, Dict, List, Optional, Set, Tuple |
9 | 10 |
|
@@ -232,6 +233,69 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule": |
232 | 233 | mmcif_metadata=self.mmcif_metadata, |
233 | 234 | ) |
234 | 235 |
|
| 236 | + def crop_chains_with_masks( |
| 237 | + self, chain_ids_and_lengths: List[Tuple[str, int]], crop_masks: List[np.ndarray] |
| 238 | + ): |
| 239 | + """ |
| 240 | + Crop the chains and metadata within a Biomolecule |
| 241 | + to only include the specified chain residues. |
| 242 | + """ |
| 243 | + assert len(chain_ids_and_lengths) == len( |
| 244 | + crop_masks |
| 245 | + ), "The number of chains and crop masks must be equal." |
| 246 | + raise NotImplementedError("Chain cropping is not yet implemented.") |
| 247 | + |
| 248 | + def contiguous_crop(self, n_res: int) -> "Biomolecule": |
| 249 | + """ |
| 250 | + Crop a Biomolecule to only include contiguous |
| 251 | + polymer residues and/or ligand atoms for each chain. |
| 252 | + """ |
| 253 | + chain_ids_and_lengths = list(collections.Counter(self.chain_id).items()) |
| 254 | + random.shuffle(chain_ids_and_lengths) |
| 255 | + crop_masks = create_contiguous_crop_masks(chain_ids_and_lengths, n_res) |
| 256 | + self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks) |
| 257 | + |
| 258 | + def spatial_crop(self) -> "Biomolecule": |
| 259 | + """ |
| 260 | + Crop a Biomolecule to only include polymer residues and ligand atoms |
| 261 | + near a (random) reference atom within a sampled chain/interface. |
| 262 | + """ |
| 263 | + raise NotImplementedError("Spatial cropping is not yet implemented.") |
| 264 | + |
| 265 | + def spatial_interface_crop(self) -> "Biomolecule": |
| 266 | + """ |
| 267 | + Crop a Biomolecule to only include contiguous polymer residues |
| 268 | + and/or ligand atoms for each chain. |
| 269 | + """ |
| 270 | + raise NotImplementedError("Spatial interface cropping is not yet implemented.") |
| 271 | + |
| 272 | + |
| 273 | +@typecheck |
| 274 | +def create_contiguous_crop_masks( |
| 275 | + chain_ids_and_lengths: List[Tuple[str, int]], n_res: int |
| 276 | +) -> List[np.ndarray]: |
| 277 | + """ |
| 278 | + Create contiguous crop masks for each given chain. |
| 279 | + Implements Algorithm 1 from the AlphaFold-Multimer paper. |
| 280 | + """ |
| 281 | + m_ks = [] |
| 282 | + n_added = 0 |
| 283 | + n_remaining = n_res |
| 284 | + for chain_id_and_length in chain_ids_and_lengths: |
| 285 | + n_k = chain_id_and_length[1] |
| 286 | + n_remaining -= n_k |
| 287 | + crop_size_max = min(n_res - n_added, n_k) |
| 288 | + # NOTE: `max(0, n_remaining)` was analytically added to prevent invalid crop sizes. |
| 289 | + crop_size_min = min(n_k, max(0, n_res - (n_added + max(0, n_remaining)))) |
| 290 | + crop_size = random.randrange(crop_size_min, crop_size_max + 1) |
| 291 | + n_added += crop_size |
| 292 | + crop_start = random.randrange(0, n_k - crop_size + 1) |
| 293 | + m_k = np.zeros(n_k, dtype=bool) |
| 294 | + keep = np.arange(crop_start, crop_start + crop_size) |
| 295 | + m_k[keep] = True |
| 296 | + m_ks.append(m_k) |
| 297 | + return m_ks |
| 298 | + |
235 | 299 |
|
236 | 300 | @typecheck |
237 | 301 | def get_residue_constants( |
@@ -307,7 +371,8 @@ def get_unique_res_atom_names( |
307 | 371 |
|
308 | 372 | @typecheck |
309 | 373 | def _from_mmcif_object( |
310 | | - mmcif_object: mmcif_parsing.MmcifObject, chain_ids: Optional[Set[str]] = None, |
| 374 | + mmcif_object: mmcif_parsing.MmcifObject, |
| 375 | + chain_ids: Optional[Set[str]] = None, |
311 | 376 | ) -> Biomolecule: |
312 | 377 | """Takes a Biopython structure/model mmCIF object and creates a `Biomolecule` instance. |
313 | 378 |
|
@@ -543,7 +608,9 @@ def _from_mmcif_object( |
543 | 608 |
|
544 | 609 |
|
545 | 610 | @typecheck |
546 | | -def from_mmcif_string(mmcif_str: str, file_id: str, chain_ids: Optional[Set[str]] = None) -> Biomolecule: |
| 611 | +def from_mmcif_string( |
| 612 | + mmcif_str: str, file_id: str, chain_ids: Optional[Set[str]] = None |
| 613 | +) -> Biomolecule: |
547 | 614 | """Takes a mmCIF string and constructs a `Biomolecule` object. |
548 | 615 |
|
549 | 616 | WARNING: All non-standard residue types will be converted into UNK. All |
|
0 commit comments