 logger = logging.getLogger('pypef.llm.esm_lora_tune')
 
 import torch
-import numpy as np
-from scipy.stats import spearmanr
-from tqdm import tqdm
 
 from peft import LoraConfig, get_peft_model
 from transformers import logging as hf_logging
 hf_logging.set_verbosity_error()
 
-from pypef.utils.helpers import get_device
-from pypef.plm.utils import corr_loss, get_batches, load_model_and_tokenizer
+from pypef.plm.utils import load_model_and_tokenizer
 
 
 def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
@@ -43,140 +39,3 @@ def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
     lora_model = get_peft_model(base_model, peft_config)
     optimizer = torch.optim.Adam(lora_model.parameters(), lr=0.01)
     return base_model, lora_model, tokenizer, optimizer
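For orientation, this is how the returned objects are typically consumed; a minimal usage sketch, not code from this file (`print_trainable_parameters()` is peft's standard helper for reporting how few parameters LoRA leaves trainable):

```python
base_model, lora_model, tokenizer, optimizer = get_esm_models(
    model='facebook/esm1v_t33_650M_UR90S_3')
# With LoRA, only the injected low-rank adapter weights require gradients,
# so the Adam optimizer above updates a small fraction of the 650M parameters.
lora_model.print_trainable_parameters()
```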
-
-
-
-def get_y_pred_scores(encoded_sequences, attention_masks,
-                      model, device: str | None = None):
-    if device is None:
-        device = get_device()
-    model = model.to(device)
-    out = model(encoded_sequences.to(device), attention_masks.to(device),
-                output_hidden_states=True)
-    logits = out.logits
-    token_probs = torch.log_softmax(logits, dim=-1)
-    for i_s, sequence in enumerate(encoded_sequences):
-        for i_aa, aa in enumerate(sequence):
-            # alternative: use Tensor.index_select() function
-            if i_aa == 0:
-                seq_log_probs = token_probs[i_s, i_aa, aa].reshape(1)
-            else:
-                seq_log_probs = torch.cat(
-                    (seq_log_probs, token_probs[i_s, i_aa, aa].reshape(1)), 0)
-        if i_s == 0:
-            log_probs = torch.sum(torch.Tensor(seq_log_probs)).reshape(1)
-        else:
-            log_probs = torch.cat(
-                (log_probs, torch.sum(torch.Tensor(seq_log_probs)).reshape(1)), 0)
-    return log_probs
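The nested loop above collects each token's own log-probability one scalar at a time (the `torch.Tensor(...)` re-wrapping is also redundant, since `seq_log_probs` is already a tensor). A behaviorally equivalent vectorized sketch, assuming `encoded_sequences` has shape `(batch, seq_len)` with int64 token ids; the helper name is ours, not part of this file:

```python
def get_y_pred_scores_vectorized(encoded_sequences, attention_masks, model, device):
    # Gather-based alternative to the per-token Python loop.
    model = model.to(device)
    out = model(encoded_sequences.to(device), attention_masks.to(device))
    token_probs = torch.log_softmax(out.logits, dim=-1)   # (batch, seq_len, vocab)
    # Pick out, at every position, the log-probability of the token actually there.
    seq_log_probs = torch.gather(
        token_probs, 2, encoded_sequences.to(device).unsqueeze(-1)).squeeze(-1)
    return seq_log_probs.sum(dim=1)                       # (batch,) per-sequence score
```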
-
-
-def esm_test(xs, attention_mask, scores, loss_fn, model,
-             device: str | None = None, verbose: bool = True):
-    if device is None:
-        device = get_device()
-    attention_masks = torch.Tensor(np.full(
-        shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
-    logger.info(f'Inferring ESM model for testing using {device.upper()} device...')
-    model = model.to(device)
-    xs, attention_masks, scores = (
-        torch.Tensor(xs).to(device), attention_masks.to(device),
-        torch.Tensor(scores).to(torch.float).to(device)
-    )
-    pbar_epochs = tqdm(zip(xs, attention_masks, scores), total=len(xs), disable=not verbose)
-    for i, (xs_b, attns_b, scores_b) in enumerate(pbar_epochs):
-        xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
-        with torch.no_grad():
-            y_preds = get_y_pred_scores(xs_b, attns_b, model, device)
-        if i == 0:
-            y_preds_total = y_preds
-            scores_total = scores_b
-        else:
-            y_preds_total = torch.cat((y_preds_total, y_preds))
-            scores_total = torch.cat((scores_total, scores_b))
-        batch_loss = loss_fn(scores_b, y_preds)
-        total_loss = loss_fn(torch.flatten(scores_total), torch.flatten(y_preds_total))
-        batch_scorr = spearmanr(scores_b.cpu(), y_preds.cpu())[0]
-        total_scorr = spearmanr(scores_total.cpu(), y_preds_total.cpu())[0]
-        pbar_epochs.set_description(
-            f"Testing: Batch {i + 1}/{len(xs)} | Batch loss: {batch_loss:.4f} (SpearCorr: "
-            f"{batch_scorr:.4f}) | Total loss: {total_loss:.4f} (SpearCorr: {total_scorr:.4f})")
-    logger.info(f"Test performance: Loss: {total_loss:.4f}, SpearCorr: {total_scorr:.4f} "
-                f"({device.upper()})")
-    return torch.flatten(scores).detach().cpu(), torch.flatten(y_preds_total).detach().cpu()
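A hypothetical call sketch; the input names and loss function are ours (the file previously imported its own `corr_loss` from `pypef.plm.utils`, see the removed import above):

```python
# Evaluate the (fine-tuned) model on held-out variants; esm_test expects
# loss_fn(targets, predictions) and returns flattened CPU tensors.
y_true, y_pred = esm_test(
    test_xs, attention_mask, test_ys,   # assumed pre-batched inputs
    loss_fn=torch.nn.MSELoss(),         # stand-in for the removed corr_loss
    model=lora_model,
)
```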
-
-
-def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=False):
-    if device is None:
-        device = get_device()
-    attention_masks = torch.Tensor(np.full(
-        shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
-    if verbose:
-        logger.info(f'Inferring ESM model for predictions using {device.upper()} device...')
-    for i, (xs_b, am_b) in enumerate(tqdm(
-            zip(xs, attention_masks), total=len(xs),
-            desc=f"ESM inference - processing sequences ({device.upper()})",
-            disable=not verbose
-    )):
-        xs_b = xs_b.to(torch.int64)
-        with torch.no_grad():
-            y_preds = get_y_pred_scores(xs_b, am_b, model, device)
-        if i == 0:
-            y_preds_total = y_preds
-        else:
-            y_preds_total = torch.cat((y_preds_total, y_preds))
-    return torch.flatten(y_preds_total)
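One caveat worth noting: `esm_infer` wraps the forward pass in `torch.no_grad()` but never switches the model out of training mode, so putting the model into eval mode first is a sensible precaution. A sketch with assumed names:

```python
lora_model.eval()   # disable dropout; esm_infer itself does not toggle train/eval mode
preds = esm_infer(x_batches, attention_mask, lora_model, verbose=True)
```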
-
-
-def esm_train(
-        xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
-        device: str | None = None, seed: int | None = None,
-        n_batch_grad_accumulations: int = 1, verbose: bool = True,
-        progress_cb=None, abort_cb=None
-):
-    if seed is not None:
-        torch.manual_seed(seed)
-    if device is None:
-        device = get_device()
-    print(f'Training ESM model using {device.upper()} device '
-          f'(N_Train={len(torch.flatten(scores))})...')
-    model = model.to(device)
-    attention_masks = torch.Tensor(np.full(
-        shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
-    xs, attention_masks, scores = xs.to(device), attention_masks.to(device), scores.to(device)
-    pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
-    loss = np.nan
-    for epoch in pbar_epochs:
-        try:
-            pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}. Loss: {loss.detach():>1f}')
-        except AttributeError:
-            pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}')
-        model.train()
-        pbar_batches = tqdm(
-            zip(xs, attention_masks, scores),
-            total=len(xs), leave=False, disable=not verbose
-        )
-        for batch, (xs_b, attns_b, scores_b) in enumerate(pbar_batches):
-            if abort_cb and abort_cb():
-                return
-            xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
-            y_preds_b = get_y_pred_scores(xs_b, attns_b, model, device=device)
-            loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
-            if progress_cb:
-                progress_cb(epoch - 1, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
-            loss.backward()
-            if (batch + 1) % n_batch_grad_accumulations == 0 or (batch + 1) == len(pbar_batches):
-                optimizer.step()
-                optimizer.zero_grad()
-            pbar_batches.set_description(
-                f"Epoch: {epoch}. Loss: {loss.detach():>1f} "
-                f"[batch: {batch + 1}/{len(xs)} | sequence: "
-                f"{(batch + 1) * len(xs_b):>5d}/{len(xs) * len(xs_b)}] ({device.upper()})"
-            )
-        if progress_cb:
-            progress_cb(epoch, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
-        y_preds_b = y_preds_b.detach()
-    model.train(False)
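Dividing each batch loss by `n_batch_grad_accumulations` before `backward()` makes the accumulated gradient an average over the micro-batches, so stepping every N batches approximates one optimizer step on an N-times-larger batch. An end-to-end sketch under assumed names:

```python
base_model, lora_model, tokenizer, optimizer = get_esm_models()
loss_fn = torch.nn.MSELoss()   # stand-in, as in the esm_test sketch above
esm_train(train_xs, attention_mask, train_ys, loss_fn, lora_model, optimizer,
          n_epochs=3, n_batch_grad_accumulations=4, seed=42)
y_true, y_pred = esm_test(test_xs, attention_mask, test_ys, loss_fn, lora_model)
```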
-
-
-