Skip to content

Commit 62a3f45

Browse files
committed
parameter for maximum length (default: 1002)
1 parent 06ab981 commit 62a3f45

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import requests
2525
import torch
2626
from Bio import SwissProt
27+
from torch.utils.data import DataLoader
2728

2829
from chebai.preprocessing import reader as dr
2930
from chebai.preprocessing.datasets.base import _DynamicDataset
@@ -74,6 +75,11 @@ def __init__(self, **kwargs):
7475
self.go_branch: str = self._get_go_branch(**kwargs)
7576
super(_GOUniProtDataExtractor, self).__init__(**kwargs)
7677

78+
self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
79+
assert (
80+
self.max_sequence_length >= 1
81+
), "Max sequence length should be greater than or equal to 1."
82+
7783
@classmethod
7884
def _get_go_branch(cls, **kwargs) -> str:
7985
"""
@@ -397,15 +403,14 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame:
397403
}
398404
# https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8
399405
AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"}
400-
MAX_LENGTH = 1002
401406

402407
for record in swiss_data:
403408
if record.data_class != "Reviewed":
404409
# To consider only manually-annotated swiss data
405410
continue
406411

407-
if not record.sequence or record.sequence_length > MAX_LENGTH:
408-
# Consider protein with only sequence representation and a maximum length of 1002
412+
if not record.sequence:
413+
# Consider protein with only sequence representation
409414
continue
410415

411416
if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
@@ -524,6 +529,29 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
524529

525530
return df_train, df_val, df_test
526531

532+
# ------------------------------ Phase: DataLoaders -----------------------------------
533+
def dataloader(self, kind: str, **kwargs) -> DataLoader:
534+
"""
535+
Returns a DataLoader object with truncated sequences for the specified kind of data (train, val, or test).
536+
537+
This method overrides the dataloader method from the superclass. After fetching the dataset from the
538+
superclass, it truncates the 'features' of each data instance to a maximum length specified by
539+
`self.max_sequence_length`.
540+
541+
Args:
542+
kind (str): The kind of data to load (e.g., 'train', 'val', 'test').
543+
**kwargs: Additional keyword arguments passed to the superclass dataloader method.
544+
545+
Returns:
546+
DataLoader: A DataLoader object with the truncated sequences.
547+
"""
548+
dataloader = super().dataloader(kind, **kwargs)
549+
550+
# Truncate the 'features' to max_sequence_length for each instance
551+
for instance in dataloader.dataset:
552+
instance["features"] = instance["features"][: self.max_sequence_length]
553+
return dataloader
554+
527555
# ------------------------------ Phase: Raw Properties -----------------------------------
528556
@property
529557
def base_dir(self) -> str:
@@ -619,7 +647,7 @@ def select_classes(
619647
- The `THRESHOLD` attribute, which defines the minimum number of annotations required to select a GO term, should be defined in the subclass.
620648
"""
621649
# Retrieve the DataFrame containing GO annotations per protein from the keyword arguments
622-
data_df: pd.DataFrame = kwargs.get("data_df", None)
650+
data_df = kwargs.get("data_df", None)
623651
if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
624652
raise AttributeError(
625653
"The 'data_df' argument must be provided and must be a non-empty pandas DataFrame."

0 commit comments

Comments
 (0)