@@ -73,13 +73,14 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC):
7373
7474 def __init__ (self , ** kwargs ):
7575 self .go_branch : str = self ._get_go_branch (** kwargs )
76- super (_GOUniProtDataExtractor , self ).__init__ (** kwargs )
7776
7877 self .max_sequence_length : int = int (kwargs .get ("max_sequence_length" , 1002 ))
7978 assert (
8079 self .max_sequence_length >= 1
8180 ), "Max sequence length should be greater than or equal to 1."
8281
82+ super (_GOUniProtDataExtractor , self ).__init__ (** kwargs )
83+
8384 if self .reader .n_gram is not None :
8485 assert self .max_sequence_length >= self .reader .n_gram , (
8586 f"max_sequence_length ({ self .max_sequence_length } ) must be greater than "
@@ -415,8 +416,8 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame:
415416 # To consider only manually-annotated swiss data
416417 continue
417418
418- if not record .sequence :
419- # Consider protein with only sequence representation
419+ if not record .sequence or len ( record . sequence ) > self . max_sequence_length :
420+ # Skip proteins that have no sequence or whose sequence length exceeds max_sequence_length
420421 continue
421422
422423 if any (aa in AMBIGUOUS_AMINO_ACIDS for aa in record .sequence ):
@@ -537,39 +538,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
537538
538539 return df_train , df_val , df_test
539540
540- # ------------------------------ Phase: DataLoaders -----------------------------------
541- def dataloader (self , kind : str , ** kwargs ) -> DataLoader :
542- """
543- Returns a DataLoader object with truncated sequences for the specified kind of data (train, val, or test).
544-
545- This method overrides the dataloader method from the superclass. After fetching the dataset from the
546- superclass, it truncates the 'features' of each data instance to a maximum length specified by
547- `self.max_sequence_length`. The truncation is adjusted based on the value of `n_gram` to ensure that
548- the correct number of amino acids is preserved in the truncated sequences.
549-
550- Args:
551- kind (str): The kind of data to load (e.g., 'train', 'val', 'test').
552- **kwargs: Additional keyword arguments passed to the superclass dataloader method.
553-
554- Returns:
555- DataLoader: A DataLoader object with the truncated sequences.
556- """
557- dataloader = super ().dataloader (kind , ** kwargs )
558-
559- if self .reader .n_gram is None :
560- # Truncate the 'features' to max_sequence_length for each instance
561- truncate_index = self .max_sequence_length
562- else :
563- # If n_gram is given, adjust truncation to ensure maximum sequence length refers to the maximum number of
564- # amino acids in sequence rather than number of n-grams. Eg, Sequence "ABCDEFGHIJ" can form 8 trigrams,
565- # if max length is 5, then only first 3 trigrams should be considered as they are formed by first 5 letters.
566- truncate_index = self .max_sequence_length - (self .reader .n_gram - 1 )
567-
568- for instance in dataloader .dataset :
569- instance ["features" ] = instance ["features" ][:truncate_index ]
570-
571- return dataloader
572-
573541 # ------------------------------ Phase: Raw Properties -----------------------------------
574542 @property
575543 def base_dir (self ) -> str :
@@ -617,13 +585,16 @@ def _name(self) -> str:
617585 """
618586 Returns the name of the dataset.
619587
588+ 'max_sequence_length' in the name indicates that proteins with sequence lengths exceeding this value are
589+ ignored in the dataset.
590+
620591 Returns:
621592 str: The dataset name, formatted with the current threshold value and/or given go_branch.
622593 """
623594 if self .go_branch != self ._ALL_GO_BRANCHES :
624- return f"GO{ self .THRESHOLD } _{ self .go_branch } "
595+ return f"GO{ self .THRESHOLD } _{ self .go_branch } _ { self . max_sequence_length } "
625596
626- return f"GO{ self .THRESHOLD } "
597+ return f"GO{ self .THRESHOLD } _ { self . max_sequence_length } "
627598
628599 def select_classes (
629600 self , g : nx .DiGraph , * args : Any , ** kwargs : Dict [str , Any ]
0 commit comments