
Commit 862c8ef

committed
scope dataset: add scope abstract code
1 parent e7b3d80 commit 862c8ef

File tree

2 files changed

+381 -0 lines changed

chebai/preprocessing/datasets/scope/__init__.py

Whitespace-only changes.
Lines changed: 381 additions & 0 deletions
@@ -0,0 +1,381 @@
import gzip
import itertools
import os
import shutil
from abc import ABC
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Generator, List, Optional, Tuple

import networkx as nx
import pandas as pd
import requests
import torch
from Bio import SeqIO

from chebai.preprocessing.datasets.base import _DynamicDataset
from chebai.preprocessing.reader import ProteinDataReader


class _SCOPeDataExtractor(_DynamicDataset, ABC):
    """
    A class for extracting and processing data from the SCOPe (Structural Classification of Proteins - extended)
    dataset and the PDB sequence dataset.

    Args:
        dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
        splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
        max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a
            default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from
            further processing.
        **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule.
    """

    _GO_DATA_INIT = "GO"
    _SWISS_DATA_INIT = "SWISS"

    # -- Index for columns of the processed `data.pkl` (derived from `_get_swiss_to_go_mapping` &
    # `_graph_to_raw_dataset`):
    #   "swiss_id"  at row index 0
    #   "accession" at row index 1
    #   "go_ids"    at row index 2
    #   "sequence"  at row index 3
    #   labels starting from row index 4
    _ID_IDX: int = 0
    _DATA_REPRESENTATION_IDX: int = 3  # here, the `sequence` column
    _LABELS_START_IDX: int = 4

    _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt"
    _PDB_SEQUENCE_DATA_URL = (
        "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz"
    )

    def __init__(
        self,
        scope_version: float,
        scope_version_train: Optional[float] = None,
        **kwargs,
    ):
        self.scope_version: float = scope_version
        self.scope_version_train: Optional[float] = scope_version_train

        super(_SCOPeDataExtractor, self).__init__(**kwargs)

        if self.scope_version_train is not None:
            # If a train version is given, instantiate the same class with "scope_version" set to
            # "scope_version_train", so that data is read from the directory of "scope_version_train".
            _init_kwargs = kwargs
            _init_kwargs["scope_version"] = self.scope_version_train
            self._scope_version_train_obj = self.__class__(
                **_init_kwargs,
            )

    @staticmethod
    def _get_scope_url(data_type: str, version_number: float) -> str:
        """
        Generates the URL for downloading SCOPe files.

        Args:
            data_type (str): The type of data (e.g., 'cla', 'hie', 'des').
            version_number (float): The version of the SCOPe file.

        Returns:
            str: The formatted SCOPe file URL.
        """
        return _SCOPeDataExtractor._SCOPE_GENERAL_URL.format(
            data_type=data_type, version_number=version_number
        )
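
    # Example: _get_scope_url("cla", 2.08)
    #   -> "https://scop.berkeley.edu/downloads/parse/dir.cla.scope.2.08-stable.txt"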

    # ------------------------------ Phase: Prepare data -----------------------------------
    def _download_required_data(self) -> str:
        """
        Downloads the required raw data: the SCOPe parseable files and the PDB sequence data.

        Returns:
            str: Path to the downloaded data.
        """
        self._download_pdb_sequence_data()
        return self._download_scope_raw_data()

    def _download_pdb_sequence_data(self) -> None:
        pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"])
        os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True)

        if not os.path.isfile(pdb_seq_file_path):
            print("Downloading PDB sequence data....")

            # Create a temporary file
            with NamedTemporaryFile(delete=False) as tf:
                temp_filename = tf.name
                print(f"Downloading to temporary file {temp_filename}")

                # Download the file
                response = requests.get(self._PDB_SEQUENCE_DATA_URL, stream=True)
                response.raise_for_status()  # Check if the request was successful
                with open(temp_filename, "wb") as temp_file:
                    shutil.copyfileobj(response.raw, temp_file)

            print(f"Downloaded to {temp_filename}")

            # Unpack the gzipped file
            try:
                print("Unzipping the file....")
                with gzip.open(temp_filename, "rb") as f_in:
                    output_file_path = pdb_seq_file_path
                    with open(output_file_path, "wb") as f_out:
                        shutil.copyfileobj(f_in, f_out)
                print(f"Unpacked and saved to {output_file_path}")

            except Exception as e:
                print(f"Failed to unpack the file: {e}")
            finally:
                # Clean up the temporary file
                os.remove(temp_filename)
                print(f"Removed temporary file {temp_filename}")

    def _download_scope_raw_data(self) -> str:
        os.makedirs(self.raw_dir, exist_ok=True)
        for data_type in ["CLA", "COM", "HIE", "DES"]:
            data_file_name = self.raw_file_names_dict[data_type]
            scope_path = os.path.join(self.raw_dir, data_file_name)
            if not os.path.isfile(scope_path):
                print(f"Missing SCOPe raw data file: {data_file_name}. Downloading...")
                r = requests.get(
                    self._get_scope_url(data_type.lower(), self.scope_version),
                    allow_redirects=False,
                    verify=False,  # Disable SSL verification
                )
                r.raise_for_status()  # Check if the request was successful
                with open(scope_path, "wb") as f:
                    f.write(r.content)
        return "dummy/path"

    def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
        pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
        for record in SeqIO.parse(
            os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
        ):
            # FASTA record IDs in pdb_seqres.txt have the form "<pdb_id>_<chain>"
            pdb_id, chain = record.id.split("_")
            pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq)
        return pdb_chain_seq_mapping
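
    # Illustrative shape of the returned mapping (sequences abbreviated):
    #   {"101m": {"A": "MVLSEGEWQLVLHVWAKVE..."}, ...}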

    def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
        print("Extracting class hierarchy...")

        # Load and preprocess CLA file
        df_cla = pd.read_csv(
            os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]),
            sep="\t",
            header=None,
            comment="#",
        )
        df_cla.columns = [
            "sid",
            "PDB_ID",
            "description",
            "sccs",
            "sunid",
            "ancestor_nodes",
        ]
        df_cla["sunid"] = pd.to_numeric(
            df_cla["sunid"], errors="coerce", downcast="integer"
        )
        # "ancestor_nodes" entries look like "cl=46456,cf=46457,...": parse them into a dict
        df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply(
            lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))}
        )
        df_cla.set_index("sunid", inplace=True)

        # Load and preprocess HIE file
        df_hie = pd.read_csv(
            os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]),
            sep="\t",
            header=None,
            comment="#",
        )
        df_hie.columns = ["sunid", "parent_sunid", "children_sunids"]
        df_hie["sunid"] = pd.to_numeric(
            df_hie["sunid"], errors="coerce", downcast="integer"
        )
        df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int)
        df_hie["children_sunids"] = df_hie["children_sunids"].apply(
            lambda x: list(map(int, x.split(","))) if x != "-" else []
        )

        # Initialize directed graph
        g = nx.DiGraph()

        # Add nodes and edges efficiently
        g.add_edges_from(
            df_hie[df_hie["parent_sunid"] != -1].apply(
                lambda row: (row["parent_sunid"], row["sunid"]), axis=1
            )
        )
        g.add_edges_from(
            df_hie.explode("children_sunids")
            .dropna()
            .apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1)
        )

        pdb_chain_seq_mapping = self._parse_pdb_sequence_file()

        node_to_pdb_id = df_cla["PDB_ID"].to_dict()

        # Iterate over a snapshot of the nodes, since chain nodes are added to `g` inside the loop
        for node in list(g.nodes()):
            # Only domain-level nodes appear in the CLA file; skip internal hierarchy nodes
            pdb_id = node_to_pdb_id.get(node)
            if pdb_id is None:
                continue
            chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {})

            # Add nodes and edges for chains in the mapping
            for chain, sequence in chain_mapping.items():
                chain_node = f"{pdb_id}_{chain}"
                g.add_node(chain_node, sequence=sequence)
                g.add_edge(node, chain_node)

        print("Computing transitive closure...")
        return nx.transitive_closure_dag(g)
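
    # The resulting DAG contains sunid hierarchy nodes (from the HIE file) plus "<pdb_id>_<chain>"
    # leaf nodes carrying a `sequence` attribute; the transitive closure adds an edge from every
    # ancestor to each of its descendants.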

    def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
        """
        Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes
        Swiss-Prot protein data and their associations with Gene Ontology (GO) terms.

        Note:
            - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value
              indicates whether a Swiss-Prot protein is associated with that GO term.
            - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins
              and GO terms.

        Data Format: pd.DataFrame
            - Column 0 : swiss_id (identifier for the Swiss-Prot protein)
            - Column 1 : accession of the protein
            - Column 2 : go_ids (associated GO terms)
            - Column 3 : sequence of the protein
            - Column 4 to column "n" : one column per selected class, with a True/False value indicating whether
              the protein is associated with this GO term.

        Args:
            g (nx.DiGraph): The class hierarchy graph.

        Returns:
            pd.DataFrame: The raw dataset created from the graph.
        """
        print("Processing graph")

        data_df = self._get_swiss_to_go_mapping()
        # Add all ancestors to the GO IDs of each protein (predecessors in the transitive closure)
        data_df["go_ids"] = data_df["go_ids"].apply(
            lambda go_ids: sorted(
                set(
                    itertools.chain.from_iterable(
                        [
                            [go_id] + list(g.predecessors(go_id))
                            for go_id in go_ids
                            if go_id in g.nodes
                        ]
                    )
                )
            )
        )
        # Initialize the GO term labels/columns to False
        selected_classes = self.select_classes(g, data_df=data_df)
        new_label_columns = pd.DataFrame(
            False, index=data_df.index, columns=selected_classes
        )
        data_df = pd.concat([data_df, new_label_columns], axis=1)

        # Set True for the corresponding GO IDs in the label columns of the DataFrame
        for index, row in data_df.iterrows():
            for go_id in row["go_ids"]:
                if go_id in data_df.columns:
                    data_df.at[index, go_id] = True

        # Keep only the rows where at least one label column (5th column onwards) is True.
        # Quote from the DeepGO paper: "For training and testing, we use proteins which have been
        # annotated with at least one GO term from the set of the GO terms for the model"
        data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
        return data_df
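
    # Illustrative row layout after this step (values are examples only):
    #   swiss_id, accession, go_ids, sequence, False, True, ... (one boolean per selected class)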

    # ------------------------------ Phase: Setup data -----------------------------------
    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
        with open(input_file_path, "rb") as input_file:
            df = pd.read_pickle(input_file)
            for row in df.values:
                labels = row[self._LABELS_START_IDX :].astype(bool)
                # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group
                # "group" is omitted (defaults to None), as this dataset has no such entity
                yield dict(
                    features=row[self._DATA_REPRESENTATION_IDX],
                    labels=labels,
                    ident=row[self._ID_IDX],
                )
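
    # Illustrative yielded item (values are examples only):
    #   dict(features="MVLSEGEWQLV...", labels=array([True, False, ...]), ident="P12345")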

    # ------------------------------ Phase: Dynamic Splits -----------------------------------
    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        try:
            filename = self.processed_file_names_dict["data"]
            data = torch.load(
                os.path.join(self.processed_dir, filename), weights_only=False
            )
        except FileNotFoundError:
            raise FileNotFoundError(
                "File data.pt doesn't exist. "
                "Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
            )

        df_data = pd.DataFrame(data)
        df_train, df_test = self.get_test_split(
            df_data, seed=self.dynamic_data_split_seed
        )

        # Get all splits
        df_train, df_val = self.get_train_val_splits_given_test(
            df_train,
            df_test,
            seed=self.dynamic_data_split_seed,
        )

        return df_train, df_val, df_test

    # ------------------------------ Phase: Raw Properties -----------------------------------
    @property
    def base_dir(self) -> str:
        """
        Returns the base directory path for storing SCOPe data.

        Returns:
            str: The path to the base directory, which is "data/SCOPe/version_<scope_version>".
        """
        return os.path.join("data", "SCOPe", f"version_{self.scope_version}")
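
    # Example: with scope_version=2.08, base_dir is "data/SCOPe/version_2.08"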

    @property
    def raw_file_names_dict(self) -> dict:
        """
        Returns a dictionary of raw file names used in data processing.

        Returns:
            dict: A dictionary mapping dataset names to their respective file names.
                For example, {"CLA": "cla.txt", "PDB": "pdb_sequences.txt"}.
        """
        return {
            "CLA": "cla.txt",
            "DES": "des.txt",
            "HIE": "hie.txt",
            "COM": "com.txt",
            "PDB": "pdb_sequences.txt",
        }


class SCOPE(_SCOPeDataExtractor):
    READER = ProteinDataReader

    @property
    def _name(self) -> str:
        return "test"

    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
        pass


if __name__ == "__main__":
    scope = SCOPE(scope_version=2.08)
    scope._parse_pdb_sequence_file()
