88from chebai .preprocessing .reader import EMBEDDING_OFFSET , DataReader
99from esm import Alphabet
1010from esm .model .esm2 import ESM2
11- from esm .pretrained import (
12- _has_regression_weights ,
13- load_model_and_alphabet_core ,
14- load_model_and_alphabet_local ,
15- )
11+ from esm .pretrained import _has_regression_weights # noqa
12+ from esm .pretrained import load_model_and_alphabet_core , load_model_and_alphabet_local
1613
1714
1815class ProteinDataReader (DataReader ):
@@ -24,7 +21,7 @@ class ProteinDataReader(DataReader):
2421 Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure
2522
2623 Args:
27- collator_kwargs (Optional[Dict[str, Any]]): Optional dictionary of keyword arguments for configuring the collator.
24+ collator_kwargs (Optional[Dict[str, Any]]): Optional dict of keyword arguments for configuring the collator.
2825 token_path (Optional[str]): Path to the token file. If not provided, it will be created automatically.
2926 kwargs: Additional keyword arguments.
3027 """
@@ -132,7 +129,7 @@ def _read_data(self, raw_data: str) -> List[int]:
132129
133130 def on_finish (self ) -> None :
134131 """
135- Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
132+ Saves the current cache of tokens to the token file.This method is called after all data processing is complete.
136133 """
137134 with open (self .token_path , "w" ) as pk :
138135 print (f"Saving { len (self .cache )} tokens to { self .token_path } ..." )
@@ -158,6 +155,8 @@ class ESM2EmbeddingReader(DataReader):
158155
159156 """
160157
158+ COLLATOR = RaggedCollator
159+
161160 # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53
162161 _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt"
163162 _REGRESSION_URL = (
@@ -270,12 +269,12 @@ def load_hub_workaround(self, url) -> torch.Tensor:
270269 )
271270 except HTTPError as e :
272271 raise Exception (
273- f"Could not load { url } . Did you specify the correct model name?"
272+ f"Could not load { url } . Did you specify the correct model name? \n Error: { e } "
274273 )
275274 return data
276275
277- @staticmethod
278- def name () -> str :
276+ @classmethod
277+ def name (cls ) -> str :
279278 """
280279 Returns the name of the data reader. This method identifies the specific type of data reader.
281280
0 commit comments