Skip to content

Commit e38d1ab

Browse files
committed
Merge branch 'protein_prediction' into additional_unit_tests
2 parents 309daed + 383b210 commit e38d1ab

File tree

1 file changed

+9
-38
lines changed

1 file changed

+9
-38
lines changed

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 9 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,14 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC):
7373

7474
def __init__(self, **kwargs):
7575
self.go_branch: str = self._get_go_branch(**kwargs)
76-
super(_GOUniProtDataExtractor, self).__init__(**kwargs)
7776

7877
self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
7978
assert (
8079
self.max_sequence_length >= 1
8180
), "Max sequence length should be greater than or equal to 1."
8281

82+
super(_GOUniProtDataExtractor, self).__init__(**kwargs)
83+
8384
if self.reader.n_gram is not None:
8485
assert self.max_sequence_length >= self.reader.n_gram, (
8586
f"max_sequence_length ({self.max_sequence_length}) must be greater than "
@@ -415,8 +416,8 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame:
415416
# To consider only manually-annotated swiss data
416417
continue
417418

418-
if not record.sequence:
419-
# Consider protein with only sequence representation
419+
if not record.sequence or len(record.sequence) > self.max_sequence_length:
420+
# Skip proteins that have no sequence or whose sequence length exceeds max_sequence_length
420421
continue
421422

422423
if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
@@ -537,39 +538,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
537538

538539
return df_train, df_val, df_test
539540

540-
# ------------------------------ Phase: DataLoaders -----------------------------------
541-
def dataloader(self, kind: str, **kwargs) -> DataLoader:
542-
"""
543-
Returns a DataLoader object with truncated sequences for the specified kind of data (train, val, or test).
544-
545-
This method overrides the dataloader method from the superclass. After fetching the dataset from the
546-
superclass, it truncates the 'features' of each data instance to a maximum length specified by
547-
`self.max_sequence_length`. The truncation is adjusted based on the value of `n_gram` to ensure that
548-
the correct number of amino acids is preserved in the truncated sequences.
549-
550-
Args:
551-
kind (str): The kind of data to load (e.g., 'train', 'val', 'test').
552-
**kwargs: Additional keyword arguments passed to the superclass dataloader method.
553-
554-
Returns:
555-
DataLoader: A DataLoader object with the truncated sequences.
556-
"""
557-
dataloader = super().dataloader(kind, **kwargs)
558-
559-
if self.reader.n_gram is None:
560-
# Truncate the 'features' to max_sequence_length for each instance
561-
truncate_index = self.max_sequence_length
562-
else:
563-
# If n_gram is given, adjust truncation to ensure maximum sequence length refers to the maximum number of
564-
# amino acids in sequence rather than number of n-grams. Eg, Sequence "ABCDEFGHIJ" can form 8 trigrams,
565-
# if max length is 5, then only first 3 trigrams should be considered as they are formed by first 5 letters.
566-
truncate_index = self.max_sequence_length - (self.reader.n_gram - 1)
567-
568-
for instance in dataloader.dataset:
569-
instance["features"] = instance["features"][:truncate_index]
570-
571-
return dataloader
572-
573541
# ------------------------------ Phase: Raw Properties -----------------------------------
574542
@property
575543
def base_dir(self) -> str:
@@ -617,13 +585,16 @@ def _name(self) -> str:
617585
"""
618586
Returns the name of the dataset.
619587
588+
'max_sequence_length' in the name indicates that proteins with sequence lengths exceeding this value are ignored
589+
in the dataset.
590+
620591
Returns:
621592
str: The dataset name, formatted with the current threshold value, the go_branch (if one is given), and max_sequence_length.
622593
"""
623594
if self.go_branch != self._ALL_GO_BRANCHES:
624-
return f"GO{self.THRESHOLD}_{self.go_branch}"
595+
return f"GO{self.THRESHOLD}_{self.go_branch}_{self.max_sequence_length}"
625596

626-
return f"GO{self.THRESHOLD}"
597+
return f"GO{self.THRESHOLD}_{self.max_sequence_length}"
627598

628599
def select_classes(
629600
self, g: nx.DiGraph, *args: Any, **kwargs: Dict[str, Any]

0 commit comments

Comments
 (0)