Skip to content

Commit 710d703

Browse files
committed
ignore proteins exceeding max len in preprocessing
1 parent 25177b3 commit 710d703

File tree

1 file changed

+7
-37
lines changed

1 file changed

+7
-37
lines changed

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,8 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame:
415415
# To consider only manually-annotated swiss data
416416
continue
417417

418-
if not record.sequence:
419-
# Consider protein with only sequence representation
418+
if not record.sequence or record.sequence > self.max_sequence_length:
419+
# Consider protein with only sequence representation and seq. length not greater than max seq. length
420420
continue
421421

422422
if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
@@ -537,39 +537,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
537537

538538
return df_train, df_val, df_test
539539

540-
# ------------------------------ Phase: DataLoaders -----------------------------------
541-
def dataloader(self, kind: str, **kwargs) -> DataLoader:
542-
"""
543-
Returns a DataLoader object with truncated sequences for the specified kind of data (train, val, or test).
544-
545-
This method overrides the dataloader method from the superclass. After fetching the dataset from the
546-
superclass, it truncates the 'features' of each data instance to a maximum length specified by
547-
`self.max_sequence_length`. The truncation is adjusted based on the value of `n_gram` to ensure that
548-
the correct number of amino acids is preserved in the truncated sequences.
549-
550-
Args:
551-
kind (str): The kind of data to load (e.g., 'train', 'val', 'test').
552-
**kwargs: Additional keyword arguments passed to the superclass dataloader method.
553-
554-
Returns:
555-
DataLoader: A DataLoader object with the truncated sequences.
556-
"""
557-
dataloader = super().dataloader(kind, **kwargs)
558-
559-
if self.reader.n_gram is None:
560-
# Truncate the 'features' to max_sequence_length for each instance
561-
truncate_index = self.max_sequence_length
562-
else:
563-
# If n_gram is given, adjust truncation to ensure maximum sequence length refers to the maximum number of
564-
# amino acids in sequence rather than number of n-grams. Eg, Sequence "ABCDEFGHIJ" can form 8 trigrams,
565-
# if max length is 5, then only first 3 trigrams should be considered as they are formed by first 5 letters.
566-
truncate_index = self.max_sequence_length - (self.reader.n_gram - 1)
567-
568-
for instance in dataloader.dataset:
569-
instance["features"] = instance["features"][:truncate_index]
570-
571-
return dataloader
572-
573540
# ------------------------------ Phase: Raw Properties -----------------------------------
574541
@property
575542
def base_dir(self) -> str:
@@ -617,13 +584,16 @@ def _name(self) -> str:
617584
"""
618585
Returns the name of the dataset.
619586
587+
'max_sequence_length' in the name indicates that proteins with sequence lengths exceeding are ignored
588+
in the dataset.
589+
620590
Returns:
621591
str: The dataset name, formatted with the current threshold value and/or given go_branch.
622592
"""
623593
if self.go_branch != self._ALL_GO_BRANCHES:
624-
return f"GO{self.THRESHOLD}_{self.go_branch}"
594+
return f"GO{self.THRESHOLD}_{self.go_branch}_{self.max_sequence_length}"
625595

626-
return f"GO{self.THRESHOLD}"
596+
return f"GO{self.THRESHOLD}_{self.max_sequence_length}"
627597

628598
def select_classes(
629599
self, g: nx.DiGraph, *args: Any, **kwargs: Dict[str, Any]

0 commit comments

Comments
 (0)