@@ -405,6 +405,15 @@ def train_llm(self):
405405 train = True ,
406406 device = self .device
407407 )
408+ y_llm_ttrain = self .llm_inference_function (
409+ xs = self .x_llm_ttrain ,
410+ model = self .llm_base_model ,
411+ input_ids = self .input_ids ,
412+ attention_mask = self .llm_attention_mask ,
413+ structure_input_ids = self .structure_input_ids ,
414+ train = True ,
415+ device = self .device
416+ )
408417 elif self .llm_key == 'esm1v' :
409418 y_llm_ttest = self .llm_inference_function (
410419 xs = x_llm_ttest_b ,
@@ -426,9 +435,6 @@ def train_llm(self):
426435 'error, try reducing the batch size or sticking to CPU device...' )
427436
428437 # void function, training model in place
429- # x_sequence_batches, score_batches, loss_fn, model, optimizer,
430- # input_ids, attention_mask, structure_input_ids,
431- # n_epochs=3, device: str | None = None, seed: int | None = None, early_stop: int = 50
432438 if self .llm_key == 'prosst' :
433439 self .llm_train_function (
434440 x_llm_ttrain_b ,
@@ -443,13 +449,20 @@ def train_llm(self):
443449 device = self .device ,
444450 #seed=self.seed
445451 )
452+ y_llm_lora_ttrain = self .llm_inference_function (
453+ xs = self .x_llm_ttrain ,
454+ model = self .llm_model ,
455+ input_ids = self .input_ids ,
456+ attention_mask = self .llm_attention_mask ,
457+ structure_input_ids = self .structure_input_ids ,
458+ device = self .device
459+ )
446460 y_llm_lora_ttest = self .llm_inference_function (
447- x_sequences = self .x_llm_ttest ,
461+ xs = self .x_llm_ttest ,
448462 model = self .llm_model ,
449463 input_ids = self .input_ids ,
450464 attention_mask = self .llm_attention_mask ,
451465 structure_input_ids = self .structure_input_ids ,
452- train = True ,
453466 device = self .device
454467 )
455468 elif self .llm_key == 'esm1v' :
0 commit comments