@@ -382,15 +382,24 @@ def name(cls) -> str:
382382 """
383383 return "protein_token"
384384
385- def __init__ (self , * args , ** kwargs ):
385+ def __init__ (self , * args , n_gram : Optional [ int ] = None , ** kwargs ):
386386 """
387387 Initializes the ProteinDataReader, loading existing tokens from the specified token file.
388388
389389 Args:
390390 *args: Additional positional arguments passed to the base class.
391391 **kwargs: Additional keyword arguments passed to the base class.
392392 """
393+ if n_gram is not None :
394+ assert (
395+ int (n_gram ) >= 2
396+ ), "Ngrams must be greater than or equal to 2 if provided."
397+ self .n_gram = int (n_gram )
398+ else :
399+ self .n_gram = None
400+
393401 super ().__init__ (* args , ** kwargs )
402+
394403 # Load the existing tokens from the token file into a cache
395404 with open (self .token_path , "r" ) as pk :
396405 self .cache = [x .strip () for x in pk ]
@@ -405,14 +414,25 @@ def _get_token_index(self, token: str) -> int:
405414 Returns:
406415 int: The index of the token, offset by the predefined EMBEDDING_OFFSET.
407416 """
408- if str (token ) not in self .AA_LETTER :
409- raise KeyError (
410- f"Invalid token '{ token } ' encountered. "
411- f"Please ensure that the input only contains valid amino acids "
412- f"20 Valid natural amino acid notation: { self .AA_LETTER } "
413- f"Refer to the amino acid sequence details here: "
414- f"https://en.wikipedia.org/wiki/Protein_primary_structure"
415- )
417+ error_str = (
418+ f"Please ensure that the input only contains valid amino acids "
419+ f"20 Valid natural amino acid notation: { self .AA_LETTER } "
420+ f"Refer to the amino acid sequence details here: "
421+ f"https://en.wikipedia.org/wiki/Protein_primary_structure"
422+ )
423+
424+ if self .n_gram is None :
425+ # Single-letter amino acid token check
426+ if str (token ) not in self .AA_LETTER :
427+ raise KeyError (f"Invalid token '{ token } ' encountered. " + error_str )
428+ else :
429+ # n-gram token validation, ensure that each component of the n-gram is valid
430+ for aa in token :
431+ if aa not in self .AA_LETTER :
432+ raise KeyError (
433+ f"Invalid token '{ token } ' encountered as part of n-gram { self .n_gram } . "
434+ + error_str
435+ )
416436
417437 if str (token ) not in self .cache :
418438 self .cache .append (str (token ))
@@ -428,7 +448,15 @@ def _read_data(self, raw_data: str) -> List[int]:
428448 Returns:
429449 List[int]: A list of integers representing the indices of the amino acid tokens.
430450 """
431- # In the case of protein sequences, each amino acid is typically represented by a single letter.
451+ if self .n_gram is not None :
452+ # Tokenize the sequence into n-grams
453+ tokens = [
454+ raw_data [i : i + self .n_gram ]
455+ for i in range (len (raw_data ) - self .n_gram + 1 )
456+ ]
457+ return [self ._get_token_index (gram ) for gram in tokens ]
458+
459+ # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation)
432460 return [self ._get_token_index (aa ) for aa in raw_data ]
433461
434462 def on_finish (self ) -> None :
0 commit comments