Skip to content

Commit 282bc09

Browse files
committed
Merge branch 'dev' into additional_unit_tests
2 parents b915b0d + 20764f7 commit 282bc09

File tree

2 files changed

+308
-22
lines changed

2 files changed

+308
-22
lines changed

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
# https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt
99
# https://www.uniprot.org/uniprotkb
1010

11-
__all__ = ["GOUniProtOver250", "GOUniProtOver50"]
11+
__all__ = [
12+
"GOUniProtOver250",
13+
"GOUniProtOver50",
14+
"EXPERIMENTAL_EVIDENCE_CODES",
15+
"AMBIGUOUS_AMINO_ACIDS",
16+
]
1217

1318
import gzip
1419
import itertools
@@ -25,11 +30,24 @@
2530
import requests
2631
import torch
2732
from Bio import SwissProt
28-
from torch.utils.data import DataLoader
2933

3034
from chebai.preprocessing import reader as dr
3135
from chebai.preprocessing.datasets.base import _DynamicDataset
3236

37+
EXPERIMENTAL_EVIDENCE_CODES = {
38+
"EXP",
39+
"IDA",
40+
"IPI",
41+
"IMP",
42+
"IGI",
43+
"IEP",
44+
"TAS",
45+
"IC",
46+
}
47+
48+
# https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8
49+
AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"}
50+
3351

3452
class _GOUniProtDataExtractor(_DynamicDataset, ABC):
3553
"""
@@ -343,13 +361,15 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
343361
data_df = self._get_swiss_to_go_mapping()
344362
# add ancestors to go ids
345363
data_df["go_ids"] = data_df["go_ids"].apply(
346-
lambda go_ids: list(
347-
itertools.chain.from_iterable(
348-
[
349-
[go_id] + list(g.predecessors(go_id))
350-
for go_id in go_ids
351-
if go_id in g.nodes
352-
]
364+
lambda go_ids: sorted(
365+
set(
366+
itertools.chain.from_iterable(
367+
[
368+
[go_id] + list(g.predecessors(go_id))
369+
for go_id in go_ids
370+
if go_id in g.nodes
371+
]
372+
)
353373
)
354374
)
355375
)
@@ -410,19 +430,6 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame:
410430
)
411431
)
412432

413-
EXPERIMENTAL_EVIDENCE_CODES = {
414-
"EXP",
415-
"IDA",
416-
"IPI",
417-
"IMP",
418-
"IGI",
419-
"IEP",
420-
"TAS",
421-
"IC",
422-
}
423-
# https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8
424-
AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"}
425-
426433
for record in swiss_data:
427434
if record.data_class != "Reviewed":
428435
# To consider only manually-annotated swiss data
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
__all__ = ["SwissProteinPretrain"]
2+
3+
import os
4+
from abc import ABC
5+
from collections import OrderedDict
6+
from typing import Any, Dict, Generator, List, Tuple
7+
8+
import networkx as nx
9+
import pandas as pd
10+
import torch
11+
from Bio import SwissProt
12+
from sklearn.model_selection import train_test_split
13+
14+
from chebai.preprocessing.datasets.base import _DynamicDataset
15+
from chebai.preprocessing.datasets.go_uniprot import (
16+
AMBIGUOUS_AMINO_ACIDS,
17+
EXPERIMENTAL_EVIDENCE_CODES,
18+
GOUniProtOver250,
19+
)
20+
from chebai.preprocessing.reader import ProteinDataReader
21+
22+
23+
class _ProteinPretrainingData(_DynamicDataset, ABC):
24+
"""
25+
Data module for pretraining protein sequences, specifically designed for Swiss-UniProt data. It includes methods for
26+
data preparation, loading, and dynamic splitting of protein sequences.
27+
The data is parsed and filtered to only select proteins with no associated `valid` Gene Ontology (GO) labels.
28+
A valid GO label is one that has one of the evidence codes defined in `EXPERIMENTAL_EVIDENCE_CODES`.
29+
"""
30+
31+
_ID_IDX: int = 0
32+
_DATA_REPRESENTATION_IDX: int = 1 # Index of `sequence` column
33+
34+
def __init__(self, **kwargs):
35+
"""
36+
Initializes the data module with a `GOUniProtOver250` extractor object, which is used to download and locate the raw Swiss-Prot data.
37+
38+
Args:
39+
**kwargs: Additional arguments for the superclass initialization. `max_sequence_length` (default: 1002) is also read from these arguments.
40+
"""
41+
self._go_uniprot_extractor = GOUniProtOver250()
42+
assert self._go_uniprot_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES
43+
44+
self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
45+
assert (
46+
self.max_sequence_length >= 1
47+
), "Max sequence length should be greater than or equal to 1."
48+
49+
super(_ProteinPretrainingData, self).__init__(**kwargs)
50+
51+
if self.reader.n_gram is not None:
52+
assert self.max_sequence_length >= self.reader.n_gram, (
53+
f"max_sequence_length ({self.max_sequence_length}) must be greater than "
54+
f"or equal to n_gram ({self.reader.n_gram})."
55+
)
56+
57+
# ------------------------------ Phase: Prepare data -----------------------------------
58+
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
59+
"""
60+
Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the processed data
61+
for further use.
62+
63+
Args:
64+
*args: Additional positional arguments.
65+
**kwargs: Additional keyword arguments.
66+
"""
67+
processed_name = self.processed_dir_main_file_names_dict["data"]
68+
if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
69+
print("Missing processed data file (`data.pkl` file)")
70+
os.makedirs(self.processed_dir_main, exist_ok=True)
71+
self._download_required_data()
72+
protein_df = self._parse_protein_data_for_pretraining()
73+
self.save_processed(protein_df, processed_name)
74+
75+
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
76+
# Method not required, as Swiss-UniProt has no ontological data
77+
pass
78+
79+
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
80+
# Method not required, as Swiss-UniProt has no ontological data
81+
pass
82+
83+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
84+
# Method not required, as Swiss-UniProt has no ontological data
85+
pass
86+
87+
def _download_required_data(self) -> str:
88+
"""
89+
Downloads the required Swiss-Prot data using the GOUniProt extractor class.
90+
91+
Returns:
92+
str: Path to the downloaded data.
93+
"""
94+
return self._go_uniprot_extractor._download_swiss_uni_prot_data()
95+
96+
def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
97+
"""
98+
Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins that do not have any valid
99+
Gene Ontology (GO) label. A valid GO label is one that has one of the following evidence codes
100+
(EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC).
101+
102+
The DataFrame includes the following columns:
103+
- "swiss_id": The unique identifier for each Swiss-Prot record.
104+
- "sequence": The protein sequence.
105+
106+
Note:
107+
We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z, *) in their sequence.
108+
109+
Returns:
110+
pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with no associated valid GO label.
111+
"""
112+
print("Parsing swiss uniprot raw data....")
113+
114+
swiss_ids, sequences = [], []
115+
116+
swiss_data = SwissProt.parse(
117+
open(
118+
os.path.join(
119+
self._go_uniprot_extractor.raw_dir,
120+
self._go_uniprot_extractor.raw_file_names_dict["SwissUniProt"],
121+
),
122+
"r",
123+
)
124+
)
125+
126+
for record in swiss_data:
127+
if record.data_class != "Reviewed":
128+
# To consider only manually-annotated swiss data
129+
continue
130+
131+
if not record.sequence:
132+
# Consider only proteins that have a sequence representation
133+
continue
134+
135+
if len(record.sequence) > self.max_sequence_length:
136+
# Consider only proteins whose sequence length is not greater than the max sequence length
137+
continue
138+
139+
if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
140+
# Skip proteins with ambiguous amino acid codes
141+
continue
142+
143+
has_valid_associated_go_label = False
144+
for cross_ref in record.cross_references:
145+
if cross_ref[0] == self._go_uniprot_extractor._GO_DATA_INIT:
146+
147+
if len(cross_ref) <= 3:
148+
# No evidence code
149+
continue
150+
151+
# https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L63-L66
152+
evidence_code = cross_ref[3].split(":")[0]
153+
if evidence_code in EXPERIMENTAL_EVIDENCE_CODES:
154+
has_valid_associated_go_label = True
155+
break
156+
157+
if has_valid_associated_go_label:
158+
# Skip proteins that have at least one associated GO label with a valid (experimental) evidence code
159+
continue
160+
161+
swiss_ids.append(record.entry_name)
162+
sequences.append(record.sequence)
163+
164+
data_dict = OrderedDict(
165+
swiss_id=swiss_ids, # swiss_id column at index 0
166+
sequence=sequences, # Sequence column at index 1
167+
)
168+
169+
return pd.DataFrame(data_dict)
170+
171+
# ------------------------------ Phase: Setup data -----------------------------------
172+
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
173+
"""
174+
Loads data from a pickled file and yields individual dictionaries for each row.
175+
176+
The pickled file is expected to contain rows with the following structure:
177+
- Data at row index `self._ID_IDX`: ID of the protein data instance (`swiss_id`)
178+
- Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein
179+
180+
This method is used by `_load_data_from_file` to generate dictionaries that are then
181+
processed and converted into a list of dictionaries containing the features and labels.
182+
183+
Args:
184+
input_file_path (str): The path to the pickled input file.
185+
186+
Yields:
187+
Dict[str, Any]: A dictionary containing:
188+
- `features` (str): The sequence data from the file.
189+
- `ident` (Any): The identifier from row index 0.
190+
- `labels`: Set to None
191+
"""
192+
with open(input_file_path, "rb") as input_file:
193+
df = pd.read_pickle(input_file)
194+
for row in df.values:
195+
yield dict(
196+
features=row[self._DATA_REPRESENTATION_IDX],
197+
ident=row[self._ID_IDX],
198+
labels=None,
199+
)
200+
201+
# ------------------------------ Phase: Dynamic Splits -----------------------------------
202+
def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
203+
"""
204+
Loads encoded data and generates training, validation, and test splits.
205+
206+
This method attempts to load encoded data from a file named `data.pt`. It then splits this data into
207+
training, validation, and test sets.
208+
209+
Raises:
210+
FileNotFoundError: If the `data.pt` file does not exist. Ensure that `prepare_data` and/or
211+
`setup` methods are called to generate the necessary dataset files.
212+
213+
Returns:
214+
Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames:
215+
- Training set
216+
- Validation set
217+
- Test set
218+
"""
219+
try:
220+
filename = self.processed_file_names_dict["data"]
221+
data_go = torch.load(
222+
os.path.join(self.processed_dir, filename), weights_only=False
223+
)
224+
except FileNotFoundError:
225+
raise FileNotFoundError(
226+
f"File data.pt doesn't exists. "
227+
f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
228+
)
229+
230+
df_go_data = pd.DataFrame(data_go)
231+
train_df_go, df_test = train_test_split(
232+
df_go_data,
233+
train_size=self.train_split,
234+
random_state=self.dynamic_data_split_seed,
235+
)
236+
237+
# Get all splits
238+
df_train, df_val = train_test_split(
239+
train_df_go,
240+
train_size=self.train_split,
241+
random_state=self.dynamic_data_split_seed,
242+
)
243+
244+
return df_train, df_val, df_test
245+
246+
# ------------------------------ Phase: Raw Properties -----------------------------------
247+
@property
248+
def base_dir(self) -> str:
249+
"""
250+
str: The base directory for pretraining data storage.
251+
"""
252+
return os.path.join(self._go_uniprot_extractor.base_dir, "Pretraining")
253+
254+
@property
255+
def raw_dir(self) -> str:
256+
"""Name of the directory where the raw data is stored."""
257+
return self._go_uniprot_extractor.raw_dir
258+
259+
260+
class SwissProteinPretrain(_ProteinPretrainingData):
261+
"""
262+
Data module for Swiss-Prot protein pretraining, inheriting from `_ProteinPretrainingData`.
263+
This class is specifically designed to handle data processing and loading for Swiss-Prot-based protein datasets.
264+
265+
Attributes:
266+
READER (Type): The data reader class used to load and process protein pretraining data.
267+
"""
268+
269+
READER = ProteinDataReader
270+
271+
@property
272+
def _name(self) -> str:
273+
"""
274+
The name identifier for this data module.
275+
276+
Returns:
277+
str: A string identifier of the form "Swiss_<max_sequence_length>" (e.g. "Swiss_1002"), representing the name of this data module.
278+
"""
279+
return f"Swiss_{self.max_sequence_length}"

0 commit comments

Comments
 (0)