@@ -40,8 +40,20 @@ def __init__(self, **kwargs):
4040 """
4141 self ._go_uniprot_extractor = GOUniProtOver250 ()
4242 assert self ._go_uniprot_extractor .go_branch == GOUniProtOver250 ._ALL_GO_BRANCHES
43+
44+ self .max_sequence_length : int = int (kwargs .get ("max_sequence_length" , 1002 ))
45+ assert (
46+ self .max_sequence_length >= 1
47+ ), "Max sequence length should be greater than or equal to 1."
48+
4349 super (_ProteinPretrainingData , self ).__init__ (** kwargs )
4450
51+ if self .reader .n_gram is not None :
52+ assert self .max_sequence_length >= self .reader .n_gram , (
53+ f"max_sequence_length ({ self .max_sequence_length } ) must be greater than "
54+ f"or equal to n_gram ({ self .reader .n_gram } )."
55+ )
56+
4557 # ------------------------------ Phase: Prepare data -----------------------------------
4658 def prepare_data (self , * args : Any , ** kwargs : Any ) -> None :
4759 """
@@ -120,6 +132,10 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
120132 # Consider protein with only sequence representation
121133 continue
122134
135+ if len (record .sequence ) > self .max_sequence_length :
136+ # Skip proteins whose sequence is longer than max_sequence_length
137+ continue
138+
123139 if any (aa in AMBIGUOUS_AMINO_ACIDS for aa in record .sequence ):
124140 # Skip proteins with ambiguous amino acid codes
125141 continue
@@ -260,4 +276,4 @@ def _name(self) -> str:
260276 Returns:
261277 str: A string identifier made of the prefix "Swiss_" followed by max_sequence_length, representing the name of this data module.
262278 """
263- return "SwissProteinPretrain "
279+ return f"Swiss_ { self . max_sequence_length } "
0 commit comments