Skip to content

Commit 196d662

Browse files
committed
reader: add collator to esm reader
1 parent c89f26d commit 196d662

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

chebai_proteins/preprocessing/reader.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88
from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader
99
from esm import Alphabet
1010
from esm.model.esm2 import ESM2
11-
from esm.pretrained import (
12-
_has_regression_weights,
13-
load_model_and_alphabet_core,
14-
load_model_and_alphabet_local,
15-
)
11+
from esm.pretrained import _has_regression_weights # noqa
12+
from esm.pretrained import load_model_and_alphabet_core, load_model_and_alphabet_local
1613

1714

1815
class ProteinDataReader(DataReader):
@@ -24,7 +21,7 @@ class ProteinDataReader(DataReader):
2421
Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure
2522
2623
Args:
27-
collator_kwargs (Optional[Dict[str, Any]]): Optional dictionary of keyword arguments for configuring the collator.
24+
collator_kwargs (Optional[Dict[str, Any]]): Optional dict of keyword arguments for configuring the collator.
2825
token_path (Optional[str]): Path to the token file. If not provided, it will be created automatically.
2926
kwargs: Additional keyword arguments.
3027
"""
@@ -132,7 +129,7 @@ def _read_data(self, raw_data: str) -> List[int]:
132129

133130
def on_finish(self) -> None:
134131
"""
135-
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
132+
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
136133
"""
137134
with open(self.token_path, "w") as pk:
138135
print(f"Saving {len(self.cache)} tokens to {self.token_path}...")
@@ -158,6 +155,8 @@ class ESM2EmbeddingReader(DataReader):
158155
159156
"""
160157

158+
COLLATOR = RaggedCollator
159+
161160
# https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53
162161
_MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt"
163162
_REGRESSION_URL = (
@@ -270,12 +269,12 @@ def load_hub_workaround(self, url) -> torch.Tensor:
270269
)
271270
except HTTPError as e:
272271
raise Exception(
273-
f"Could not load {url}. Did you specify the correct model name?"
272+
f"Could not load {url}. Did you specify the correct model name? \n Error: {e}"
274273
)
275274
return data
276275

277-
@staticmethod
278-
def name() -> str:
276+
@classmethod
277+
def name(cls) -> str:
279278
"""
280279
Returns the name of the data reader. This method identifies the specific type of data reader.
281280

0 commit comments

Comments
 (0)