Skip to content

Commit e6d601c

Browse files
committed
Update hybrid model: works again
1 parent 835b6b7 commit e6d601c

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,15 @@ def train_llm(self):
405405
train=True,
406406
device=self.device
407407
)
408+
y_llm_ttrain = self.llm_inference_function(
409+
xs=self.x_llm_ttrain,
410+
model=self.llm_base_model,
411+
input_ids=self.input_ids,
412+
attention_mask=self.llm_attention_mask,
413+
structure_input_ids=self.structure_input_ids,
414+
train=True,
415+
device=self.device
416+
)
408417
elif self.llm_key == 'esm1v':
409418
y_llm_ttest = self.llm_inference_function(
410419
xs=x_llm_ttest_b,
@@ -426,9 +435,6 @@ def train_llm(self):
426435
'error, try reducing the batch size or sticking to CPU device...')
427436

428437
# void function, training model in place
429-
# x_sequence_batches, score_batches, loss_fn, model, optimizer,
430-
# input_ids, attention_mask, structure_input_ids,
431-
# n_epochs=3, device: str | None = None, seed: int | None = None, early_stop: int = 50
432438
if self.llm_key == 'prosst':
433439
self.llm_train_function(
434440
x_llm_ttrain_b,
@@ -443,13 +449,20 @@ def train_llm(self):
443449
device=self.device,
444450
#seed=self.seed
445451
)
452+
y_llm_lora_ttrain = self.llm_inference_function(
453+
xs=self.x_llm_ttrain,
454+
model=self.llm_model,
455+
input_ids=self.input_ids,
456+
attention_mask=self.llm_attention_mask,
457+
structure_input_ids=self.structure_input_ids,
458+
device=self.device
459+
)
446460
y_llm_lora_ttest = self.llm_inference_function(
447-
x_sequences=self.x_llm_ttest,
461+
xs=self.x_llm_ttest,
448462
model=self.llm_model,
449463
input_ids=self.input_ids,
450464
attention_mask=self.llm_attention_mask,
451465
structure_input_ids=self.structure_input_ids,
452-
train=True,
453466
device=self.device
454467
)
455468
elif self.llm_key == 'esm1v':

0 commit comments

Comments (0)