@@ -48,7 +48,7 @@ def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
4848 return base_model , lora_model , tokenizer , optimizer
4949
5050
51- def esm_tokenize_sequences (sequences , tokenizer , max_length , verbose = True ):
51+ def tokenize_sequences (sequences , tokenizer , max_length , verbose = True ):
5252 tokenized_sequences = []
5353 for seq in tqdm (sequences , desc = 'Tokenizing sequences for ESM modeling' , disable = not verbose ):
5454 encoded_sequence , attention_mask = tokenizer (
@@ -154,18 +154,49 @@ def esm_unmasked_wt_score(
154154 ):
155155 if device is None :
156156 device = get_device ()
157- wt_input_ids = wt_input_ids .unsqueeze (0 )
157+ if wt_input_ids .dim () == 1 :
158+ wt_input_ids = wt_input_ids .unsqueeze (0 )
159+ structure_input_ids = kwargs .get ("structure_input_ids" , None )
158160 attention_masks = torch .Tensor (np .full (
159161 shape = np .shape (wt_input_ids ), fill_value = attention_mask )).to (torch .int64 )
160162 if train :
161- outputs = model (wt_input_ids .to (device ), attention_masks .to (device ),
162- output_hidden_states = False )
163+ if structure_input_ids is not None :
164+ outputs = model (
165+ input_ids = wt_input_ids .to (device ),
166+ attention_mask = attention_masks .to (device ),
167+ ss_input_ids = structure_input_ids .to (device )
168+ )
169+ else :
170+ outputs = model (
171+ wt_input_ids .to (device ),
172+ attention_masks .to (device ),
173+ output_hidden_states = False
174+ )
163175 else :
164176 with torch .no_grad ():
165- outputs = model (wt_input_ids .to (device ), attention_masks .to (device ),
166- output_hidden_states = False )
177+ if structure_input_ids is not None :
178+ outputs = model (
179+ input_ids = wt_input_ids .to (device ),
180+ attention_mask = attention_masks .to (device ),
181+ ss_input_ids = structure_input_ids .to (device )
182+ )
183+ else :
184+ outputs = model (
185+ wt_input_ids .to (device ),
186+ attention_masks .to (device ),
187+ output_hidden_states = False
188+ )
189+
167190 logits = outputs .logits
168- token_probs = torch .log_softmax (logits , dim = - 1 ).squeeze (0 )
191+ logits = logits .squeeze (0 ) # remove batch dim
192+ #print('logits.shape:', logits.shape)
193+ # Better make sure that special tokens are always removed / masked
194+ # and only pure amino acid sequence tokens are present / unmasked
195+ #logits = logits[1:-1] # drop CLS/EOS
196+ token_probs = torch .log_softmax (logits , dim = - 1 )
197+ assert len (tokenized_sequences [0 ]) == token_probs .shape [0 ], f"{ len (tokenized_sequences [0 ])} != { token_probs .shape [0 ]} "
198+ #print('token_probs.shape:', token_probs.shape)
199+
169200 for i_s , tokenized_seq in enumerate (tokenized_sequences ):
170201 for i_aa , aa in enumerate (tokenized_seq ):
171202 # alternative: use Tensor.index_select() function
@@ -417,7 +448,7 @@ def esm_train(
417448def esm_setup (sequences , device : str | None = None , verbose : bool = True ):
418449 esm_base_model , esm_lora_model , esm_tokenizer , esm_optimizer = get_esm_models ()
419450 esm_base_model = esm_base_model .to (device )
420- x_esm , esm_attention_mask = esm_tokenize_sequences (
451+ x_esm , esm_attention_mask = tokenize_sequences (
421452 sequences , esm_tokenizer , max_length = len (sequences [0 ]), verbose = verbose )
422453 llm_dict_esm = {
423454 'esm1v' : {
0 commit comments