from functools import partial

import torch
+ from beartype.typing import Literal
from torch import tensor
from torch.nn import Module

- from beartype.typing import Literal
-
- from alphafold3_pytorch.tensor_typing import (
-     typecheck,
-     Float,
-     Int
- )
-
- from alphafold3_pytorch.common.biomolecule import (
-     get_residue_constants,
- )
-
- from alphafold3_pytorch.inputs import (
-     IS_PROTEIN,
- )
+ from alphafold3_pytorch.common.biomolecule import get_residue_constants
+ from alphafold3_pytorch.inputs import IS_PROTEIN
+ from alphafold3_pytorch.tensor_typing import Float, Int, typecheck

# functions

@@ -28,41 +17,48 @@ def join(arr, delimiter = ''):  # just redo an ugly part of python

# constants

- aa_constants = get_residue_constants(res_chem_index = IS_PROTEIN)
- restypes_index = dict(enumerate(aa_constants.restypes))
+ aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
+ restypes = aa_constants.restypes + ["X"]

# class

+
class ESMWrapper(Module):
+     """A wrapper for the ESM model to provide PLM embeddings."""
+
    def __init__(
        self,
-         esm_name,
-         repr_layer = 33
+         esm_name: str,
+         repr_layer: int = 33,
    ):
        super().__init__()
        import esm
+
        self.repr_layer = repr_layer
        self.model, alphabet = esm.pretrained.load_model_and_alphabet_hub(esm_name)
        self.batch_converter = alphabet.get_batch_converter()

        self.embed_dim = self.model.embed_dim
-         self.register_buffer('dummy', tensor(0), persistent = False)
+         self.register_buffer("dummy", tensor(0), persistent=False)

    @torch.no_grad()
    @typecheck
    def forward(
-         self,
-         aa_ids: Int['b n']
-     ) -> Float['b n dpe']:
+         self, aa_ids: Int["b n"]  # type: ignore
+     ) -> Float["b n dpe"]:  # type: ignore
+         """Get PLM embeddings for a batch of (pseudo-)protein sequences.

+         :param aa_ids: A batch of amino acid residue indices.
+         :return: The PLM embeddings for the input sequences.
+         """
        device, repr_layer = self.dummy.device, self.repr_layer

        sequence_data = [
            (
-                 f"molecule{i}",
-                 join([restypes_index.get(i, 'X') for i in ids]),
+                 f"molecule{mol_idx}",
+                 join([(restypes[i] if 0 <= i < len(restypes) else "X") for i in ids]),
            )
-             for i, ids in enumerate(aa_ids)
+             for mol_idx, ids in enumerate(aa_ids)
        ]

        _, _, batch_tokens = self.batch_converter(sequence_data)
@@ -80,64 +76,62 @@ def forward(

        return plm_embeddings

+
class ProstT5Wrapper(Module):
+     """A wrapper for the ProstT5 model to provide PLM embeddings."""
+
    def __init__(self):
        super().__init__()
-         from transformers import T5Tokenizer, T5EncoderModel
+         from transformers import T5EncoderModel, T5Tokenizer

-         self.register_buffer('dummy', tensor(0), persistent = False)
+         self.register_buffer("dummy", tensor(0), persistent=False)

-         self.tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case = False)
+         self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
        self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
        self.embed_dim = 1024

    @torch.no_grad()
    @typecheck
    def forward(
-         self,
-         aa_ids: Int['b n']
-     ) -> Float['b n dpe']:
+         self, aa_ids: Int["b n"]  # type: ignore
+     ) -> Float["b n dpe"]:  # type: ignore
+         """Get PLM embeddings for a batch of (pseudo-)protein sequences.

+         :param aa_ids: A batch of amino acid residue indices.
+         :return: The PLM embeddings for the input sequences.
+         """
        device, seq_len = self.dummy.device, aa_ids.shape[-1]

        str_sequences = [
-             join([restypes_index.get(i, 'X') for i in ids])
-             for i, ids in enumerate(aa_ids)
+             join([(restypes[i] if 0 <= i < len(restypes) else "X") for i in ids]) for ids in aa_ids
        ]

        # following the readme at https://github.com/mheinzinger/ProstT5

-         str_sequences = [join(list(re.sub(r"[UZOB]", "X", str_seq)), ' ') for str_seq in str_sequences]
+         str_sequences = [
+             join(list(re.sub(r"[UZOB]", "X", str_seq)), " ") for str_seq in str_sequences
+         ]

        # encode to ids

        inputs = self.tokenizer.batch_encode_plus(
-             str_sequences,
-             add_special_tokens = True,
-             padding = "longest",
-             return_tensors = 'pt'
+             str_sequences, add_special_tokens=True, padding="longest", return_tensors="pt"
        ).to(device)

        # forward through plm

-         embeddings = self.model(
-             inputs.input_ids,
-             attention_mask = inputs.attention_mask
-         )
+         embeddings = self.model(inputs.input_ids, attention_mask=inputs.attention_mask)

        # remove prefix

-         plm_embedding = embeddings.last_hidden_state[:, 1:(seq_len + 1)]
+         plm_embedding = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
        return plm_embedding

+
# PLM embedding type and registry

PLMRegistry = dict(
-     esm2_t33_650M_UR50D = partial(ESMWrapper, 'esm2_t33_650M_UR50D'),
-     prostT5 = ProstT5Wrapper
+     esm2_t33_650M_UR50D=partial(ESMWrapper, "esm2_t33_650M_UR50D"), prostT5=ProstT5Wrapper
)

- PLMEmbedding = Literal[
-     "esm2_t33_650M_UR50D",
-     "prostT5"
- ]
+ PLMEmbedding = Literal["esm2_t33_650M_UR50D", "prostT5"]
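
A minimal usage sketch of the registry shown above, not part of the diff itself: it assumes this module is importable as alphafold3_pytorch.plm, that the fair-esm dependency and its pretrained weights can be fetched, and the names plm and aa_ids are illustrative only.

    import torch
    from alphafold3_pytorch.plm import PLMRegistry  # assumed import path

    # look up the ESM2 wrapper by its registry key and instantiate it
    # (this downloads the pretrained checkpoint on first use)
    plm = PLMRegistry["esm2_t33_650M_UR50D"]()

    # a toy batch of residue indices of shape (batch, seq_len);
    # indices 0-19 map to the standard amino acid restypes
    aa_ids = torch.randint(0, 20, (2, 16))

    # forward runs under torch.no_grad() and returns per-residue embeddings
    # of shape (batch, seq_len, plm.embed_dim)
    embeddings = plm(aa_ids)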