Skip to content

Commit 5fb3e2d

Browse files
authored
Add support for suffix .gz gzipped cif and a3m inputs for training (#294)
* support gzip cif and a3m files * also gz a .cif file * also gz a .cif file - reverse the type gzip * add .gz support in detection * add .gz support in detection error msg * force add gzipped a3ms
1 parent ace4968 commit 5fb3e2d

File tree

9 files changed

+48
-24975
lines changed

9 files changed

+48
-24975
lines changed

alphafold3_pytorch/data/mmcif_parsing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io
66
import itertools
77
import logging
8+
import gzip
89
from collections import defaultdict
910
from operator import itemgetter
1011
from beartype.typing import Any, Mapping, Optional, Sequence, Set, Tuple
@@ -14,7 +15,7 @@
1415
from Bio.Data import PDBData
1516

1617
from alphafold3_pytorch.utils.data_utils import is_polymer, is_water, matrix_rotate
17-
18+
from alphafold3_pytorch.data.msa_parsing import is_gzip_file
1819
# Type aliases:
1920
ChainId = str
2021
PdbHeader = Mapping[str, Any]
@@ -763,8 +764,13 @@ def parse_mmcif_object(
763764
filepath: str, file_id: str, auth_chains: bool = True, auth_residues: bool = True
764765
) -> MmcifObject:
765766
"""Parse an mmCIF file into an `MmcifObject` containing a BioPython `Structure` object as well as associated metadata."""
766-
with open(filepath, "r") as f:
767-
mmcif_string = f.read()
767+
768+
if is_gzip_file(filepath):
769+
with gzip.open(filepath, "r") as f:
770+
mmcif_string = f.read()
771+
else:
772+
with open(filepath, "r") as f:
773+
mmcif_string = f.read()
768774

769775
parsing_result = parse(
770776
file_id=file_id,

alphafold3_pytorch/data/msa_parsing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import random
77
import re
88
import string
9+
import binascii
910

1011
import hashlib
1112
from cachetools import cached, LRUCache
@@ -273,3 +274,20 @@ def parse_a3m(a3m_string: str, msa_type: MSA_TYPE) -> Msa:
273274
descriptions=descriptions,
274275
msa_type=msa_type,
275276
)
277+
278+
@typecheck
279+
def is_gzip_file(f: str) -> bool:
280+
"""Checks whether an input file (i.e an a3m MSA file) is gzipped
281+
282+
Method copied from Phispy see https://github.com/linsalrob/PhiSpy/blob/master/PhiSpyModules/helper_functions.py
283+
284+
This is an elegant solution to test whether a file is gzipped by reading the first two characters.
285+
286+
Args:
287+
f (str): The file to test.
288+
289+
Returns:
290+
bool: True if the file is gzip compressed, otherwise False.
291+
"""
292+
with open(f, "rb") as i:
293+
return binascii.hexlify(i.read(2)) == b"1f8b"

alphafold3_pytorch/inputs.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import glob
55
import json
66
import os
7+
import gzip
78
import random
89
import statistics
910
import traceback
@@ -2203,9 +2204,9 @@ def __post_init__(self):
22032204
if exists(self.mmcif_filepath):
22042205
if not os.path.exists(self.mmcif_filepath):
22052206
raise FileNotFoundError(f"mmCIF file not found: {self.mmcif_filepath}.")
2206-
if not self.mmcif_filepath.endswith(".cif"):
2207+
if not (self.mmcif_filepath.endswith(".cif") or self.mmcif_filepath.endswith(".cif.gz")):
22072208
raise ValueError(
2208-
f"mmCIF file `{self.mmcif_filepath}` must have a `.cif` file extension."
2209+
f"mmCIF file `{self.mmcif_filepath}` must have a `.cif` or `.cif.gz` file extension."
22092210
)
22102211
elif not exists(self.biomol):
22112212
raise ValueError("Either an mmCIF file or a `Biomolecule` object must be provided.")
@@ -2825,9 +2826,9 @@ def load_msa_from_msa_dir(
28252826
msa_fpath_pattern = ""
28262827
if exists(msa_dir):
28272828
msa_fpath_pattern = (
2828-
os.path.join(msa_dir, f"{pdb_id.split('-assembly1')[0]}_*", "a3m", "*.a3m")
2829+
os.path.join(msa_dir, f"{pdb_id.split('-assembly1')[0]}_*", "a3m*")
28292830
if distillation
2830-
else os.path.join(msa_dir, f"{file_id}{chain_id}_*.a3m")
2831+
else os.path.join(msa_dir, f"{file_id}{chain_id}_*.a3m*")
28312832
)
28322833
msa_fpaths = glob.glob(msa_fpath_pattern)
28332834

@@ -2844,11 +2845,19 @@ def load_msa_from_msa_dir(
28442845
# into the MSAs as unknown amino acid residues.
28452846
chain_msas = []
28462847
for msa_fpath in msa_fpaths:
2847-
with open(msa_fpath, "r") as f:
2848-
msa = f.read()
2849-
msa = msa_parsing.parse_a3m(msa, chain_msa_type)
2850-
if len(chain_sequence) == len(msa.sequences[0]):
2851-
chain_msas.append(msa)
2848+
if msa_parsing.is_gzip_file(msa_fpath):
2849+
with gzip.open(msa_fpath, "r") as f:
2850+
msa = f.read()
2851+
msa = msa_parsing.parse_a3m(msa, chain_msa_type)
2852+
if len(chain_sequence) == len(msa.sequences[0]):
2853+
chain_msas.append(msa)
2854+
else:
2855+
with open(msa_fpath, "r") as f:
2856+
msa = f.read()
2857+
msa = msa_parsing.parse_a3m(msa, chain_msa_type)
2858+
if len(chain_sequence) == len(msa.sequences[0]):
2859+
chain_msas.append(msa)
2860+
28522861

28532862
if not chain_msas:
28542863
raise ValueError(
@@ -4304,13 +4313,13 @@ def __init__(
43044313
sampler_pdb_ids = set(self.sampler.mappings.get_column("pdb_id").to_list())
43054314
self.files = {
43064315
os.path.splitext(os.path.basename(filepath.name))[0]: filepath
4307-
for filepath in folder.glob(os.path.join("**", "*.cif"))
4316+
for filepath in folder.glob(os.path.join("**", "*.cif*"))
43084317
if os.path.splitext(os.path.basename(filepath.name))[0] in sampler_pdb_ids
43094318
}
43104319
else:
43114320
self.files = {
43124321
os.path.splitext(os.path.basename(file.name))[0]: file
4313-
for file in folder.glob(os.path.join("**", "*.cif"))
4322+
for file in folder.glob(os.path.join("**", "*.cif*"))
43144323
}
43154324

43164325
if exists(filter_out_pdb_ids):
@@ -4484,7 +4493,7 @@ def __init__(
44844493

44854494
self.files = {
44864495
os.path.splitext(os.path.basename(file.name))[0]: file
4487-
for file in folder.glob(os.path.join("**", "*.cif"))
4496+
for file in folder.glob(os.path.join("**", "*.cif*"))
44884497
if os.path.splitext(os.path.basename(file.name))[0].split("-")[1]
44894498
in self.uniprot_to_pdb_id_mapping
44904499
}

data/pdb_data/data_caches/msa/train_msas/209d-assembly1C_protein.a3m

Lines changed: 0 additions & 4 deletions
This file was deleted.
Binary file not shown.

0 commit comments

Comments
 (0)