Skip to content

Commit 77dd8f4

Browse files
authored
Add (random) contiguous sequence cropping for Biomolecule objects (#101)
* Update data_pipeline.py * Update biomolecule.py
1 parent 154ac2a commit 77dd8f4

File tree

2 files changed

+126
-11
lines changed

2 files changed

+126
-11
lines changed

alphafold3_pytorch/common/biomolecule.py

Lines changed: 117 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
import io
77
import random
8+
from functools import partial
89
from types import ModuleType
910
from typing import Any, Dict, List, Optional, Set, Tuple
1011

@@ -235,40 +236,150 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule":
235236

236237
def crop_chains_with_masks(
237238
self, chain_ids_and_lengths: List[Tuple[str, int]], crop_masks: List[np.ndarray]
238-
):
239+
) -> "Biomolecule":
239240
"""
240241
Crop the chains and metadata within a Biomolecule
241242
to only include the specified chain residues.
242243
"""
243244
assert len(chain_ids_and_lengths) == len(
244245
crop_masks
245246
), "The number of chains and crop masks must be equal."
246-
raise NotImplementedError("Chain cropping is not yet implemented.")
247+
assert not all(
248+
crop_mask.all() for crop_mask in crop_masks
249+
), "Not all tokens can be cropped out of a Biomolecule."
250+
251+
# collect metadata for each chain
252+
253+
unique_chain_ids = np.unique(self.chain_id)
254+
chains_to_remove = {
255+
chain_id_and_length[0]
256+
for chain_id_and_length, crop_mask in zip(chain_ids_and_lengths, crop_masks)
257+
if not crop_mask.any()
258+
}
259+
subset_chain_id_mapping = {
260+
chain_id: n
261+
for n, chain_id in enumerate(unique_chain_ids)
262+
if chain_id not in chains_to_remove
263+
}
264+
subset_chain_index_mapping = {
265+
n: chain_id for chain_id, n in subset_chain_id_mapping.items()
266+
}
267+
chain_id_to_index = {
268+
chain_id_and_length[0]: i
269+
for i, chain_id_and_length in enumerate(chain_ids_and_lengths)
270+
}
271+
272+
# create metadata for cropping
273+
274+
chain_mask = np.concatenate(
275+
[crop_masks[chain_id_to_index[c_id]] for c_id in unique_chain_ids]
276+
)
277+
chain_residue_index = np.array(
278+
list(zip(self.chain_index[chain_mask], self.residue_index[chain_mask]))
279+
)
280+
# NOTE: We must only consider unique chain-residue index pairs here,
281+
# as otherwise we might count each ligand heavy atom as a residue in this mapping
282+
subset_chain_residue_mapping = set(map(tuple, chain_residue_index))
283+
284+
# manually subset certain Biomolecule metadata
285+
286+
entity_to_chain = {
287+
entity_id: [
288+
chain_id for chain_id in chain_ids if chain_id in subset_chain_index_mapping
289+
]
290+
for entity_id, chain_ids in self.entity_to_chain.items()
291+
if any(chain_id in subset_chain_index_mapping for chain_id in chain_ids)
292+
}
293+
mmcif_to_author_chain = {
294+
mmcif_chain: author_chain_id
295+
for mmcif_chain, author_chain_id in self.mmcif_to_author_chain.items()
296+
if author_chain_id in subset_chain_index_mapping
297+
}
298+
299+
# construct a new cropped Biomolecule
247300

248-
def contiguous_crop(self, n_res: int) -> "Biomolecule":
301+
return Biomolecule(
302+
atom_positions=self.atom_positions[chain_mask],
303+
atom_name=self.atom_name[chain_mask],
304+
restype=self.restype[chain_mask],
305+
atom_mask=self.atom_mask[chain_mask],
306+
residue_index=self.residue_index[chain_mask],
307+
chain_index=self.chain_index[chain_mask],
308+
chain_id=self.chain_id[chain_mask],
309+
b_factors=self.b_factors[chain_mask],
310+
chemid=self.chemid[chain_mask],
311+
chemtype=self.chemtype[chain_mask],
312+
bonds=[
313+
bond
314+
for bond in self.bonds
315+
if bond.ptnr1_auth_asym_id not in chains_to_remove
316+
and bond.ptnr2_auth_asym_id not in chains_to_remove
317+
],
318+
unique_res_atom_names=[
319+
unique_res_atom_names
320+
for unique_res_atom_names in self.unique_res_atom_names
321+
if unique_res_atom_names[1] not in chains_to_remove
322+
and (subset_chain_id_mapping[unique_res_atom_names[1]], unique_res_atom_names[2])
323+
in subset_chain_residue_mapping
324+
],
325+
author_cri_to_new_cri={
326+
author_cri: new_cri
327+
for author_cri, new_cri in self.author_cri_to_new_cri.items()
328+
if new_cri[0] in subset_chain_index_mapping
329+
},
330+
chem_comp_table=self.chem_comp_table,
331+
entity_to_chain=entity_to_chain,
332+
mmcif_to_author_chain=mmcif_to_author_chain,
333+
mmcif_metadata=self.mmcif_metadata,
334+
)
335+
336+
def contiguous_crop(self, n_res: int = 384) -> "Biomolecule":
249337
"""
250338
Crop a Biomolecule to only include contiguous
251339
polymer residues and/or ligand atoms for each chain.
252340
"""
253341
chain_ids_and_lengths = list(collections.Counter(self.chain_id).items())
254342
random.shuffle(chain_ids_and_lengths)
255343
crop_masks = create_contiguous_crop_masks(chain_ids_and_lengths, n_res)
256-
self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks)
344+
return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks)
257345

258-
def spatial_crop(self) -> "Biomolecule":
346+
def spatial_crop(
347+
self, chain_1: Optional[str] = None, chain_2: Optional[str] = None
348+
) -> "Biomolecule":
259349
"""
260350
Crop a Biomolecule to only include polymer residues and ligand atoms
261351
near a (random) reference atom within a sampled chain/interface.
262352
"""
263353
raise NotImplementedError("Spatial cropping is not yet implemented.")
264354

265-
def spatial_interface_crop(self) -> "Biomolecule":
355+
def spatial_interface_crop(
356+
self, chain_1: Optional[str] = None, chain_2: Optional[str] = None
357+
) -> "Biomolecule":
266358
"""
267359
Crop a Biomolecule to only include contiguous polymer residues
268360
and/or ligand atoms for each chain.
269361
"""
270362
raise NotImplementedError("Spatial interface cropping is not yet implemented.")
271363

364+
def crop(
365+
self,
366+
contiguous_weight: float = 0.2,
367+
spatial_weight: float = 0.4,
368+
spatial_interface_weight: float = 0.4,
369+
n_res: int = 384,
370+
chain_1: Optional[str] = None,
371+
chain_2: Optional[str] = None,
372+
) -> "Biomolecule":
373+
"""Crop a Biomolecule using a randomly-sampled cropping function."""
374+
crop_fn_weights = [contiguous_weight, spatial_weight, spatial_interface_weight]
375+
crop_fns = [
376+
partial(self.contiguous_crop, n_res=n_res),
377+
partial(self.spatial_crop, chain_1=chain_1, chain_2=chain_2),
378+
partial(self.spatial_interface_crop, chain_1=chain_1, chain_2=chain_2),
379+
]
380+
crop_fn = random.choices(crop_fns, crop_fn_weights)[0]
381+
return crop_fn()
382+
272383

273384
@typecheck
274385
def create_contiguous_crop_masks(

alphafold3_pytorch/data/data_pipeline.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,19 @@ def make_mmcif_features(
150150
file_id=file_id,
151151
)
152152
mmcif_feats, assembly = make_mmcif_features(mmcif_object)
153-
# cropped_assembly = assembly.contiguous_crop(384)
153+
cropped_assembly = assembly.crop(
154+
contiguous_weight=1.0,
155+
spatial_weight=0.0,
156+
spatial_interface_weight=0.0,
157+
)
154158
mmcif_string = to_mmcif(
155-
assembly,
156-
# cropped_assembly,
159+
# assembly,
160+
cropped_assembly,
157161
file_id=file_id,
158162
gapless_poly_seq=True,
159163
insert_alphafold_mmcif_metadata=False,
160-
unique_res_atom_names=assembly.unique_res_atom_names,
161-
# unique_res_atom_names=cropped_assembly.unique_res_atom_names,
164+
# unique_res_atom_names=assembly.unique_res_atom_names,
165+
unique_res_atom_names=cropped_assembly.unique_res_atom_names,
162166
)
163167
with open(os.path.basename(filepath).replace(".cif", "_reconstructed.cif"), "w") as f:
164168
f.write(mmcif_string)

0 commit comments

Comments
 (0)