Commit 6b097cc

fix esm2 issue
1 parent 3a7a320 commit 6b097cc

2 files changed: +16 -2 lines changed


chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def __init__(

         super(_GOUniProtDataExtractor, self).__init__(**kwargs)

-        if self.reader.n_gram is not None:
+        if hasattr(self.reader, "n_gram") and self.reader.n_gram is not None:
             assert self.max_sequence_length >= self.reader.n_gram, (
                 f"max_sequence_length ({self.max_sequence_length}) must be greater than "
                 f"or equal to n_gram ({self.reader.n_gram})."

chebai_proteins/preprocessing/reader.py

Lines changed: 15 additions & 1 deletion
@@ -181,7 +181,7 @@ def __init__(
         self.truncation_length = truncation_length
         self.toks_per_batch = toks_per_batch
         self.return_contacts = return_contacts
-        self.repr_layer = repr_layer
+        self.repr_layer = int(repr_layer)

         self._model: Optional[ESM2] = None
         self._alphabet: Optional[Alphabet] = None
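
The cast matters when repr_layer is supplied as a string, for example from a config file or the repr_layer="6" argument in the __main__ demo added further down. Assuming the embedding extraction follows the usage pattern from the facebookresearch/esm README that this commit adds to the docstring references, the model's output representations are keyed by the integer layer index, so a string key would never match. A rough sketch of that pattern against the public fair-esm API (not this repository's wrapper; variable names here are only illustrative):

import torch
import esm

# Same small 6-layer checkpoint as in the demo below; weights are fetched on first use.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

_, _, tokens = batch_converter([("protein1", "MKTFFVAGVILLLLPLVSSQ")])

repr_layer = int("6")  # config/CLI values often arrive as strings
with torch.no_grad():
    results = model(tokens, repr_layers=[repr_layer], return_contacts=False)

# "representations" is keyed by the integer layer index; the string "6" would be a KeyError.
embedding = results["representations"][repr_layer]
print(embedding.shape)  # torch.Size([1, 22, 320]) for this 20-residue sequence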
@@ -355,6 +355,7 @@ def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tensor:

         References:
             https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107
+            https://github.com/facebookresearch/esm?tab=readme-ov-file#usage-

         Returns:
             torch.Tensor: Protein embedding from the specified representation layer.
@@ -393,3 +394,16 @@ def on_finish(self) -> None:
             None
         """
         pass
+
+
+if __name__ == "__main__":
+    reader = ProteinDataReader()
+    sample_sequence = "MKTFFVAGVILLLLPLVSSQCVNLTTRTQSRGDPTQKARPEPT"
+    token_indices = reader._read_data(sample_sequence)
+    print(f"Token indices for the sequence: {token_indices}")
+
+    esm_reader = ESM2EmbeddingReader(
+        model_name="esm2_t6_8M_UR50D", repr_layer="6", device=torch.device("cpu")
+    )
+    embeddings = esm_reader._read_data(sample_sequence)
+    print(f"ESM2 embeddings shape: {len(embeddings)}")
