Skip to content

Commit f14b816

Browse files
authored
Create outline of how to do cropping with Biomolecule objects (#99)
* Update biomolecule.py * Update data_pipeline.py
1 parent 92e052a commit f14b816

File tree

2 files changed

+80
-8
lines changed

2 files changed

+80
-8
lines changed

alphafold3_pytorch/common/biomolecule.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import dataclasses
55
import functools
66
import io
7+
import random
78
from types import ModuleType
89
from typing import Any, Dict, List, Optional, Set, Tuple
910

@@ -232,6 +233,69 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule":
232233
mmcif_metadata=self.mmcif_metadata,
233234
)
234235

236+
def crop_chains_with_masks(
    self, chain_ids_and_lengths: List[Tuple[str, int]], crop_masks: List[np.ndarray]
):
    """
    Restrict a Biomolecule's chains (and associated metadata) to the
    positions selected by per-chain crop masks.

    :param chain_ids_and_lengths: `(chain_id, length)` pairs, one per chain.
    :param crop_masks: One mask per entry of `chain_ids_and_lengths`.
    :raises AssertionError: If the two lists differ in length.
    :raises NotImplementedError: Always, once the length check passes, since
        the cropping itself has not been written yet.
    """
    num_chains = len(chain_ids_and_lengths)
    num_masks = len(crop_masks)
    # Each chain needs exactly one mask, so the two lists must line up.
    assert num_chains == num_masks, "The number of chains and crop masks must be equal."
    raise NotImplementedError("Chain cropping is not yet implemented.")
247+
248+
def contiguous_crop(self, n_res: int) -> "Biomolecule":
    """
    Crop a Biomolecule to only include contiguous
    polymer residues and/or ligand atoms for each chain.

    :param n_res: Overall crop budget forwarded to `create_contiguous_crop_masks`.
    :return: Whatever `crop_chains_with_masks` produces for the sampled masks.
    """
    # Tally how many entries of `self.chain_id` carry each chain identifier.
    chain_ids_and_lengths = list(collections.Counter(self.chain_id).items())
    # Shuffle in place so the chain visiting order differs between calls.
    random.shuffle(chain_ids_and_lengths)
    crop_masks = create_contiguous_crop_masks(chain_ids_and_lengths, n_res)
    # Hand the result back to the caller; previously it was computed and
    # dropped, so this method returned `None` despite its annotation.
    return self.crop_chains_with_masks(chain_ids_and_lengths, crop_masks)
257+
258+
def spatial_crop(self) -> "Biomolecule":
    """
    Placeholder for spatial cropping: keeping only the polymer residues and
    ligand atoms close to a (random) reference atom drawn from a sampled
    chain/interface.

    :raises NotImplementedError: Always; the method is a stub.
    """
    message = "Spatial cropping is not yet implemented."
    raise NotImplementedError(message)
264+
265+
def spatial_interface_crop(self) -> "Biomolecule":
    """
    Crop a Biomolecule around a chain interface.

    NOTE(review): the previous docstring was a copy of the one on
    `contiguous_crop`. This is presumably meant to be the interface-restricted
    variant of spatial cropping (reference atom sampled from an interface);
    confirm the exact selection rule when implementing.

    :raises NotImplementedError: Always; the method is a stub.
    """
    raise NotImplementedError("Spatial interface cropping is not yet implemented.")
271+
272+
273+
@typecheck
def create_contiguous_crop_masks(
    chain_ids_and_lengths: List[Tuple[str, int]], n_res: int
) -> List[np.ndarray]:
    """
    Create contiguous crop masks for each given chain.
    Implements Algorithm 1 from the AlphaFold-Multimer paper.

    :param chain_ids_and_lengths: `(chain_id, length)` pairs, visited in the
        given order.
    :param n_res: Crop budget, i.e. the maximum number of positions kept
        across all chains combined.
    :return: One boolean mask of shape `(length,)` per chain, in input order,
        each with a single contiguous run of `True` values (possibly empty).
    :raises ValueError: If `n_res` is negative.
    """
    if n_res < 0:
        raise ValueError(f"n_res must be non-negative, got {n_res}.")

    m_ks = []
    n_added = 0
    # `n_remaining` tracks how many positions the chains *after* the current
    # one can still contribute, so it starts at the total chain length.
    # Starting it at `n_res` (as before) forces
    # `crop_size_min == crop_size_max` on every iteration, which turns the
    # random crop size into a deterministic greedy fill.
    n_remaining = sum(n_k for _, n_k in chain_ids_and_lengths)
    for _, n_k in chain_ids_and_lengths:
        n_remaining -= n_k
        # Upper bound: what is left of the budget, capped by the chain length.
        crop_size_max = min(n_res - n_added, n_k)
        # Lower bound: what this chain must supply so the later chains can
        # still fill the budget. `n_remaining >= 0` keeps it <= crop_size_max.
        crop_size_min = min(n_k, max(0, n_res - (n_added + n_remaining)))
        crop_size = random.randrange(crop_size_min, crop_size_max + 1)
        n_added += crop_size
        # Pick where the contiguous window of `crop_size` positions begins.
        crop_start = random.randrange(0, n_k - crop_size + 1)
        m_k = np.zeros(n_k, dtype=bool)
        m_k[crop_start : crop_start + crop_size] = True
        m_ks.append(m_k)
    return m_ks
298+
235299

236300
@typecheck
237301
def get_residue_constants(
@@ -307,7 +371,8 @@ def get_unique_res_atom_names(
307371

308372
@typecheck
309373
def _from_mmcif_object(
310-
mmcif_object: mmcif_parsing.MmcifObject, chain_ids: Optional[Set[str]] = None,
374+
mmcif_object: mmcif_parsing.MmcifObject,
375+
chain_ids: Optional[Set[str]] = None,
311376
) -> Biomolecule:
312377
"""Takes a Biopython structure/model mmCIF object and creates a `Biomolecule` instance.
313378
@@ -543,7 +608,9 @@ def _from_mmcif_object(
543608

544609

545610
@typecheck
546-
def from_mmcif_string(mmcif_str: str, file_id: str, chain_ids: Optional[Set[str]] = None) -> Biomolecule:
611+
def from_mmcif_string(
612+
mmcif_str: str, file_id: str, chain_ids: Optional[Set[str]] = None
613+
) -> Biomolecule:
547614
"""Takes a mmCIF string and constructs a `Biomolecule` object.
548615
549616
WARNING: All non-standard residue types will be converted into UNK. All

alphafold3_pytorch/data/data_pipeline.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414
FeatureDict = MutableMapping[str, np.ndarray]
1515

1616

17-
def make_sequence_features(sequence: str, description: str) -> FeatureDict:
    """Build the feature dict holding the UTF-8 encoded description and sequence."""
    encoded_description = description.encode("utf-8")
    encoded_sequence = sequence.encode("utf-8")
    # Each value is a one-element object array wrapping the encoded bytes.
    return {
        "domain_name": np.array([encoded_description], dtype=object),
        "sequence": np.array([encoded_sequence], dtype=object),
    }
2523

@@ -101,22 +99,24 @@ def make_mmcif_features(
10199
mmcif_object.chain_to_seqres[chain_id] for chain_id in mmcif_object.chain_to_seqres
102100
)
103101
description = mmcif_object.file_id
104-
num_res = len(input_sequence)
105102

106103
mmcif_feats = {}
107104

108105
mmcif_feats.update(
109106
make_sequence_features(
110107
sequence=input_sequence,
111108
description=description,
112-
num_res=num_res,
113109
)
114110
)
115111

116112
# As necessary, expand the first bioassembly/model sequence and structure, to obtain a biologically relevant complex (AF3 Supplement, Section 2.1).
117113
# Reference: https://github.com/biotite-dev/biotite/blob/1045f43f80c77a0dc00865e924442385ce8f83ab/src/biotite/structure/io/pdbx/convert.py#L1441
118114

119-
assembly = _from_mmcif_object(mmcif_object) if "assembly" in description else get_assembly(_from_mmcif_object(mmcif_object))
115+
assembly = (
116+
_from_mmcif_object(mmcif_object)
117+
if "assembly" in description
118+
else get_assembly(_from_mmcif_object(mmcif_object))
119+
)
120120

121121
mmcif_feats["all_atom_positions"] = assembly.atom_positions
122122
mmcif_feats["all_atom_mask"] = assembly.atom_mask
@@ -128,6 +128,8 @@ def make_mmcif_features(
128128
mmcif_feats["residue_index"] = assembly.residue_index
129129
mmcif_feats["restype"] = assembly.restype
130130

131+
mmcif_feats["bonds"] = mmcif_object.bonds
132+
131133
mmcif_feats["resolution"] = np.array([mmcif_object.header["resolution"]], dtype=np.float32)
132134

133135
mmcif_feats["release_date"] = np.array(
@@ -148,12 +150,15 @@ def make_mmcif_features(
148150
file_id=file_id,
149151
)
150152
mmcif_feats, assembly = make_mmcif_features(mmcif_object)
153+
# cropped_assembly = assembly.contiguous_crop(384)
151154
mmcif_string = to_mmcif(
152155
assembly,
156+
# cropped_assembly,
153157
file_id=file_id,
154158
gapless_poly_seq=True,
155159
insert_alphafold_mmcif_metadata=False,
156160
unique_res_atom_names=assembly.unique_res_atom_names,
161+
# unique_res_atom_names=cropped_assembly.unique_res_atom_names,
157162
)
158163
with open(os.path.basename(filepath).replace(".cif", "_reconstructed.cif"), "w") as f:
159164
f.write(mmcif_string)

0 commit comments

Comments
 (0)