Skip to content

Commit 11c1c34

Browse files
committed
Update dca_llm for DCA / DCA+ESM1v / DCA+ProSST hybrid modeling
1 parent 335256c commit 11c1c34

File tree

2 files changed: +30 additions, −10 deletions

pypef/hybrid/hybrid_model.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,16 @@ def train_llm(self):
396396
#x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=self.batch_size, dtype=int)
397397
if self.llm_key == 'prosst':
398398
y_llm_ttest = self.llm_inference_function(
399-
x_sequences=self.x_llm_ttest,
399+
xs=self.x_llm_ttest,
400+
model=self.llm_base_model,
401+
input_ids=self.input_ids,
402+
attention_mask=self.llm_attention_mask,
403+
structure_input_ids=self.structure_input_ids,
404+
train=True,
405+
device=self.device
406+
)
407+
y_llm_ttrain = self.llm_inference_function(
408+
xs=self.x_llm_ttrain,
400409
model=self.llm_base_model,
401410
input_ids=self.input_ids,
402411
attention_mask=self.llm_attention_mask,
@@ -442,8 +451,17 @@ def train_llm(self):
442451
device=self.device,
443452
#seed=self.seed
444453
)
454+
y_llm_lora_ttrain = self.llm_inference_function(
455+
xs=self.x_llm_ttrain,
456+
model=self.llm_model,
457+
input_ids=self.input_ids,
458+
attention_mask=self.llm_attention_mask,
459+
structure_input_ids=self.structure_input_ids,
460+
train=True,
461+
device=self.device
462+
)
445463
y_llm_lora_ttest = self.llm_inference_function(
446-
x_sequences=self.x_llm_ttest,
464+
xs=self.x_llm_ttest,
447465
model=self.llm_model,
448466
input_ids=self.input_ids,
449467
attention_mask=self.llm_attention_mask,

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
160160
hybrid_perfs = []
161161
ns_y_test = [len(variants)]
162162
for i_t, train_size in enumerate([100, 200, 1000]):
163-
prosst_lora_model = copy.deepcopy(prosst_lora_model)
164-
prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)
165-
esm_lora_model = copy.deepcopy(esm_lora_model)
166-
esm_optimizer = torch.optim.Adam(esm_lora_model.parameters(), lr=0.0001)
163+
prosst_lora_model_2 = copy.deepcopy(prosst_lora_model)
164+
prosst_optimizer = torch.optim.Adam(prosst_lora_model_2.parameters(), lr=0.0001)
165+
esm_lora_model_2 = copy.deepcopy(esm_lora_model)
166+
esm_optimizer = torch.optim.Adam(esm_lora_model_2.parameters(), lr=0.0001)
167167
print('\nTRAIN SIZE:', train_size, '\n-------------------------------------------\n')
168168
get_vram()
169169
try:
@@ -194,7 +194,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
194194
llm_dict_prosst = {
195195
'prosst': {
196196
'llm_base_model': prosst_base_model,
197-
'llm_model': prosst_lora_model,
197+
'llm_model': prosst_lora_model_2,
198198
'llm_optimizer': prosst_optimizer,
199199
'llm_train_function': prosst_train,
200200
'llm_inference_function': get_logits_from_full_seqs,
@@ -208,7 +208,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
208208
llm_dict_esm = {
209209
'esm1v': {
210210
'llm_base_model': esm_base_model,
211-
'llm_model': esm_lora_model,
211+
'llm_model': esm_lora_model_2,
212212
'llm_optimizer': esm_optimizer,
213213
'llm_train_function': esm_train,
214214
'llm_inference_function': esm_infer,
@@ -227,6 +227,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
227227
continue
228228
get_vram()
229229
for i_m, method in enumerate([None, llm_dict_esm, llm_dict_prosst]):
230+
print('~~~ ' + ['DCA hybrid', 'DCA+ESM1v hybrid', 'DCA+ProSST hybrid'][i_m] + ' ~~~')
230231
hm = DCALLMHybridModel(
231232
x_train_dca=np.array(x_dca_train),
232233
y_train=y_train,
@@ -238,7 +239,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
238239

239240
y_test_pred = hm.hybrid_prediction(
240241
x_dca=np.array(x_dca_test),
241-
x_llm=[None, np.array(x_llm_test_esm), np.array(x_llm_test_prosst)][i]
242+
x_llm=[None, np.asarray(x_llm_test_esm), np.asarray(x_llm_test_prosst)][i_m]
242243
)
243244

244245
print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}')
@@ -250,7 +251,8 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
250251
f'in N_Train = {train_size} and N_Test (N_Total - N_Train).')
251252
hybrid_perfs.append(np.nan)
252253
ns_y_test.append(np.nan)
253-
del prosst_lora_model
254+
del prosst_lora_model_2
255+
del esm_lora_model_2
254256
torch.cuda.empty_cache()
255257
gc.collect()
256258
dt = time.time() - start_time

0 commit comments

Comments (0)