 
 import torch
 from chebai.preprocessing.collate import RaggedCollator
-from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader
+from chebai.preprocessing.reader import DataReader, TokenIndexerReader
 from esm import Alphabet
 from esm.model.esm2 import ESM2
 from esm.pretrained import _has_regression_weights  # noqa
-from esm.pretrained import load_model_and_alphabet_core, load_model_and_alphabet_local
+from esm.pretrained import load_model_and_alphabet_core
 
 
-class ProteinDataReader(DataReader):
+class ProteinDataReader(TokenIndexerReader):
     """
     Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format
     suitable for model input by tokenizing them and assigning unique indices to each token.
@@ -30,12 +30,12 @@ class ProteinDataReader(DataReader):
 
     # fmt: off
     # 21 natural amino acid notation
-    AA_LETTER = [
+    AA_LETTER = {
         "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
         "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
         # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5
         "X",  # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py
-    ]
+    }
     # fmt: on
 
     def name(self) -> str:
@@ -68,10 +68,6 @@ def __init__(self, *args, n_gram: Optional[int] = None, **kwargs):
 
         super().__init__(*args, **kwargs)
 
-        # Load the existing tokens from the token file into a cache
-        with open(self.token_path, "r") as pk:
-            self.cache = [x.strip() for x in pk]
-
     def _get_token_index(self, token: str) -> int:
         """
         Returns a unique index for each token (amino acid). If the token is not already in the cache, it is added.
@@ -102,9 +98,7 @@ def _get_token_index(self, token: str) -> int:
                 + error_str
             )
 
-        if str(token) not in self.cache:
-            self.cache.append(str(token))
-        return self.cache.index(str(token)) + EMBEDDING_OFFSET
+        return super()._get_token_index(token)
 
     def _read_data(self, raw_data: str) -> List[int]:
         """
@@ -127,15 +121,6 @@ def _read_data(self, raw_data: str) -> List[int]:
         # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation)
         return [self._get_token_index(aa) for aa in raw_data]
 
-    def on_finish(self) -> None:
-        """
-        Saves the current cache of tokens to the token file.This method is called after all data processing is complete.
-        """
-        with open(self.token_path, "w") as pk:
-            print(f"Saving {len(self.cache)} tokens to {self.token_path}...")
-            print(f"First 10 tokens: {self.cache[:10]}")
-            pk.writelines([f"{c}\n" for c in self.cache])
-
 
 class ESM2EmbeddingReader(DataReader):
     """
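
The commit strips the token-cache bookkeeping out of ProteinDataReader (loading self.cache from self.token_path, appending unseen tokens, offsetting indices by EMBEDDING_OFFSET, and writing the cache back in on_finish) and replaces it with a single call to super()._get_token_index(token). The sketch below is only an assumption of what the new TokenIndexerReader base class is expected to provide, reconstructed from the code deleted in this diff; the names token_path, cache, and EMBEDDING_OFFSET come from the removed lines, while the class name and constant value used here are illustrative.

# Minimal sketch (assumption) of the behaviour ProteinDataReader now inherits
# from TokenIndexerReader, reconstructed from the code removed in this commit.

EMBEDDING_OFFSET = 10  # assumed value; the real constant lives in chebai.preprocessing.reader


class TokenIndexerReaderSketch:
    def __init__(self, token_path: str):
        self.token_path = token_path
        # Load previously seen tokens so indices stay stable across runs.
        with open(self.token_path, "r") as pk:
            self.cache = [x.strip() for x in pk]

    def _get_token_index(self, token: str) -> int:
        # Unseen tokens are appended, so every token keeps a unique, stable index.
        if str(token) not in self.cache:
            self.cache.append(str(token))
        return self.cache.index(str(token)) + EMBEDDING_OFFSET

    def on_finish(self) -> None:
        # Persist the (possibly extended) vocabulary after all data is processed.
        with open(self.token_path, "w") as pk:
            pk.writelines([f"{c}\n" for c in self.cache])

With that bookkeeping centralized in the base class, ProteinDataReader._read_data only has to map each amino acid (or n-gram) to an index, e.g. [self._get_token_index(aa) for aa in "MKV"] under the sketch above.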