|
5 | 5 | import functools |
6 | 6 | import io |
7 | 7 | import random |
| 8 | +from functools import partial |
8 | 9 | from types import ModuleType |
9 | 10 | from typing import Any, Dict, List, Optional, Set, Tuple |
10 | 11 |
|
@@ -235,40 +236,150 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule": |
235 | 236 |
|
236 | 237 | def crop_chains_with_masks( |
237 | 238 | self, chain_ids_and_lengths: List[Tuple[str, int]], crop_masks: List[np.ndarray] |
238 | | - ): |
| 239 | + ) -> "Biomolecule": |
239 | 240 | """ |
240 | 241 | Crop the chains and metadata within a Biomolecule |
241 | 242 | to only include the specified chain residues. |
242 | 243 | """ |
243 | 244 | assert len(chain_ids_and_lengths) == len( |
244 | 245 | crop_masks |
245 | 246 | ), "The number of chains and crop masks must be equal." |
246 | | - raise NotImplementedError("Chain cropping is not yet implemented.") |
| 247 | + assert not all( |
| 248 | + crop_mask.all() for crop_mask in crop_masks |
| 249 | + ), "Not all tokens can be cropped out of a Biomolecule." |
| 250 | + |
| 251 | + # collect metadata for each chain |
| 252 | + |
| 253 | + unique_chain_ids = np.unique(self.chain_id) |
| 254 | + chains_to_remove = { |
| 255 | + chain_id_and_length[0] |
| 256 | + for chain_id_and_length, crop_mask in zip(chain_ids_and_lengths, crop_masks) |
| 257 | + if not crop_mask.any() |
| 258 | + } |
| 259 | + subset_chain_id_mapping = { |
| 260 | + chain_id: n |
| 261 | + for n, chain_id in enumerate(unique_chain_ids) |
| 262 | + if chain_id not in chains_to_remove |
| 263 | + } |
| 264 | + subset_chain_index_mapping = { |
| 265 | + n: chain_id for chain_id, n in subset_chain_id_mapping.items() |
| 266 | + } |
| 267 | + chain_id_to_index = { |
| 268 | + chain_id_and_length[0]: i |
| 269 | + for i, chain_id_and_length in enumerate(chain_ids_and_lengths) |
| 270 | + } |
| 271 | + |
| 272 | + # create metadata for cropping |
| 273 | + |
| 274 | + chain_mask = np.concatenate( |
| 275 | + [crop_masks[chain_id_to_index[c_id]] for c_id in unique_chain_ids] |
| 276 | + ) |
| 277 | + chain_residue_index = np.array( |
| 278 | + list(zip(self.chain_index[chain_mask], self.residue_index[chain_mask])) |
| 279 | + ) |
| 280 | + # NOTE: We must only consider unique chain-residue index pairs here, |
| 281 | + # as otherwise we might count each ligand heavy atom as a residue in this mapping |
| 282 | + subset_chain_residue_mapping = set(map(tuple, chain_residue_index)) |
| 283 | + |
| 284 | + # manually subset certain Biomolecule metadata |
| 285 | + |
| 286 | + entity_to_chain = { |
| 287 | + entity_id: [ |
| 288 | + chain_id for chain_id in chain_ids if chain_id in subset_chain_index_mapping |
| 289 | + ] |
| 290 | + for entity_id, chain_ids in self.entity_to_chain.items() |
| 291 | + if any(chain_id in subset_chain_index_mapping for chain_id in chain_ids) |
| 292 | + } |
| 293 | + mmcif_to_author_chain = { |
| 294 | + mmcif_chain: author_chain_id |
| 295 | + for mmcif_chain, author_chain_id in self.mmcif_to_author_chain.items() |
| 296 | + if author_chain_id in subset_chain_index_mapping |
| 297 | + } |
| 298 | + |
| 299 | + # construct a new cropped Biomolecule |
247 | 300 |
|
248 | | - def contiguous_crop(self, n_res: int) -> "Biomolecule": |
| 301 | + return Biomolecule( |
| 302 | + atom_positions=self.atom_positions[chain_mask], |
| 303 | + atom_name=self.atom_name[chain_mask], |
| 304 | + restype=self.restype[chain_mask], |
| 305 | + atom_mask=self.atom_mask[chain_mask], |
| 306 | + residue_index=self.residue_index[chain_mask], |
| 307 | + chain_index=self.chain_index[chain_mask], |
| 308 | + chain_id=self.chain_id[chain_mask], |
| 309 | + b_factors=self.b_factors[chain_mask], |
| 310 | + chemid=self.chemid[chain_mask], |
| 311 | + chemtype=self.chemtype[chain_mask], |
| 312 | + bonds=[ |
| 313 | + bond |
| 314 | + for bond in self.bonds |
| 315 | + if bond.ptnr1_auth_asym_id not in chains_to_remove |
| 316 | + and bond.ptnr2_auth_asym_id not in chains_to_remove |
| 317 | + ], |
| 318 | + unique_res_atom_names=[ |
| 319 | + unique_res_atom_names |
| 320 | + for unique_res_atom_names in self.unique_res_atom_names |
| 321 | + if unique_res_atom_names[1] not in chains_to_remove |
| 322 | + and (subset_chain_id_mapping[unique_res_atom_names[1]], unique_res_atom_names[2]) |
| 323 | + in subset_chain_residue_mapping |
| 324 | + ], |
| 325 | + author_cri_to_new_cri={ |
| 326 | + author_cri: new_cri |
| 327 | + for author_cri, new_cri in self.author_cri_to_new_cri.items() |
| 328 | + if new_cri[0] in subset_chain_index_mapping |
| 329 | + }, |
| 330 | + chem_comp_table=self.chem_comp_table, |
| 331 | + entity_to_chain=entity_to_chain, |
| 332 | + mmcif_to_author_chain=mmcif_to_author_chain, |
| 333 | + mmcif_metadata=self.mmcif_metadata, |
| 334 | + ) |
| 335 | + |
| 336 | + def contiguous_crop(self, n_res: int = 384) -> "Biomolecule": |
249 | 337 | """ |
250 | 338 | Crop a Biomolecule to only include contiguous |
251 | 339 | polymer residues and/or ligand atoms for each chain. |
252 | 340 | """ |
253 | 341 | chain_ids_and_lengths = list(collections.Counter(self.chain_id).items()) |
254 | 342 | random.shuffle(chain_ids_and_lengths) |
255 | 343 | crop_masks = create_contiguous_crop_masks(chain_ids_and_lengths, n_res) |
256 | | - self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks) |
| 344 | + return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks) |
257 | 345 |
|
258 | | - def spatial_crop(self) -> "Biomolecule": |
| 346 | + def spatial_crop( |
| 347 | + self, chain_1: Optional[str] = None, chain_2: Optional[str] = None |
| 348 | + ) -> "Biomolecule": |
259 | 349 | """ |
260 | 350 | Crop a Biomolecule to only include polymer residues and ligand atoms |
261 | 351 | near a (random) reference atom within a sampled chain/interface. |
262 | 352 | """ |
263 | 353 | raise NotImplementedError("Spatial cropping is not yet implemented.") |
264 | 354 |
|
265 | | - def spatial_interface_crop(self) -> "Biomolecule": |
| 355 | + def spatial_interface_crop( |
| 356 | + self, chain_1: Optional[str] = None, chain_2: Optional[str] = None |
| 357 | + ) -> "Biomolecule": |
266 | 358 | """ |
267 | 359 | Crop a Biomolecule to only include contiguous polymer residues |
268 | 360 | and/or ligand atoms for each chain. |
269 | 361 | """ |
270 | 362 | raise NotImplementedError("Spatial interface cropping is not yet implemented.") |
271 | 363 |
|
| 364 | + def crop( |
| 365 | + self, |
| 366 | + contiguous_weight: float = 0.2, |
| 367 | + spatial_weight: float = 0.4, |
| 368 | + spatial_interface_weight: float = 0.4, |
| 369 | + n_res: int = 384, |
| 370 | + chain_1: Optional[str] = None, |
| 371 | + chain_2: Optional[str] = None, |
| 372 | + ) -> "Biomolecule": |
| 373 | + """Crop a Biomolecule using a randomly-sampled cropping function.""" |
| 374 | + crop_fn_weights = [contiguous_weight, spatial_weight, spatial_interface_weight] |
| 375 | + crop_fns = [ |
| 376 | + partial(self.contiguous_crop, n_res=n_res), |
| 377 | + partial(self.spatial_crop, chain_1=chain_1, chain_2=chain_2), |
| 378 | + partial(self.spatial_interface_crop, chain_1=chain_1, chain_2=chain_2), |
| 379 | + ] |
| 380 | + crop_fn = random.choices(crop_fns, crop_fn_weights)[0] |
| 381 | + return crop_fn() |
| 382 | + |
272 | 383 |
|
273 | 384 | @typecheck |
274 | 385 | def create_contiguous_crop_masks( |
|
0 commit comments