Skip to content

Commit d653f52

Browse files
committed
use TokenIndexerReader for ProteinDataReader
1 parent 5af20c8 commit d653f52

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

chebai_proteins/preprocessing/reader.py

Lines changed: 6 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -5,14 +5,14 @@
55

66
import torch
77
from chebai.preprocessing.collate import RaggedCollator
8-
from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader
8+
from chebai.preprocessing.reader import DataReader, TokenIndexerReader
99
from esm import Alphabet
1010
from esm.model.esm2 import ESM2
1111
from esm.pretrained import _has_regression_weights # noqa
12-
from esm.pretrained import load_model_and_alphabet_core, load_model_and_alphabet_local
12+
from esm.pretrained import load_model_and_alphabet_core
1313

1414

15-
class ProteinDataReader(DataReader):
15+
class ProteinDataReader(TokenIndexerReader):
1616
"""
1717
Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format
1818
suitable for model input by tokenizing them and assigning unique indices to each token.
@@ -30,12 +30,12 @@ class ProteinDataReader(DataReader):
3030

3131
# fmt: off
3232
# 21 natural amino acid notation
33-
AA_LETTER = [
33+
AA_LETTER = {
3434
"A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
3535
"L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
3636
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5
3737
"X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py
38-
]
38+
}
3939
# fmt: on
4040

4141
def name(self) -> str:
@@ -68,10 +68,6 @@ def __init__(self, *args, n_gram: Optional[int] = None, **kwargs):
6868

6969
super().__init__(*args, **kwargs)
7070

71-
# Load the existing tokens from the token file into a cache
72-
with open(self.token_path, "r") as pk:
73-
self.cache = [x.strip() for x in pk]
74-
7571
def _get_token_index(self, token: str) -> int:
7672
"""
7773
Returns a unique index for each token (amino acid). If the token is not already in the cache, it is added.
@@ -102,9 +98,7 @@ def _get_token_index(self, token: str) -> int:
10298
+ error_str
10399
)
104100

105-
if str(token) not in self.cache:
106-
self.cache.append(str(token))
107-
return self.cache.index(str(token)) + EMBEDDING_OFFSET
101+
return super()._get_token_index(token)
108102

109103
def _read_data(self, raw_data: str) -> List[int]:
110104
"""
@@ -127,15 +121,6 @@ def _read_data(self, raw_data: str) -> List[int]:
127121
# If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation)
128122
return [self._get_token_index(aa) for aa in raw_data]
129123

130-
def on_finish(self) -> None:
131-
"""
132-
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
133-
"""
134-
with open(self.token_path, "w") as pk:
135-
print(f"Saving {len(self.cache)} tokens to {self.token_path}...")
136-
print(f"First 10 tokens: {self.cache[:10]}")
137-
pk.writelines([f"{c}\n" for c in self.cache])
138-
139124

140125
class ESM2EmbeddingReader(DataReader):
141126
"""

0 commit comments

Comments (0)