@@ -181,7 +181,7 @@ def __init__(
181181 self .truncation_length = truncation_length
182182 self .toks_per_batch = toks_per_batch
183183 self .return_contacts = return_contacts
184- self .repr_layer = repr_layer
184+ self .repr_layer = int ( repr_layer )
185185
186186 self ._model : Optional [ESM2 ] = None
187187 self ._alphabet : Optional [Alphabet ] = None
@@ -355,6 +355,7 @@ def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tenso
355355
356356 References:
357357 https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107
358+ https://github.com/facebookresearch/esm?tab=readme-ov-file#usage-
358359
359360 Returns:
360361 torch.Tensor: Protein embedding from the specified representation layer.
@@ -393,3 +394,16 @@ def on_finish(self) -> None:
393394 None
394395 """
395396 pass
397+
398+
if __name__ == "__main__":
    # Manual smoke test: tokenize one sample protein sequence, then embed it
    # with a small ESM2 checkpoint on CPU. Uses the readers' private
    # _read_data hooks directly since this bypasses the normal pipeline.
    demo_sequence = "MKTFFVAGVILLLLPLVSSQCVNLTTRTQSRGDPTQKARPEPT"

    protein_reader = ProteinDataReader()
    indices = protein_reader._read_data(demo_sequence)
    print(f"Token indices for the sequence: {indices}")

    # repr_layer is passed as a string on purpose: __init__ coerces it
    # with int(), and this exercises that path.
    embedding_reader = ESM2EmbeddingReader(
        model_name="esm2_t6_8M_UR50D",
        repr_layer="6",
        device=torch.device("cpu"),
    )
    embedding = embedding_reader._read_data(demo_sequence)
    # NOTE(review): len() reports only the leading dimension of the
    # embedding, not its full shape — presumably sequence length; confirm.
    print(f"ESM2 embeddings shape: {len(embedding)}")
0 commit comments