Skip to content

Commit 7ebb0a2

Browse files
committed
Fix: DCA-only hybrid prediction was missing the y_ridge factor (now returns self.beta1 * y_dca + self.beta2 * y_ridge)
1 parent 1d63ae3 commit 7ebb0a2

File tree

3 files changed

+21
-42
lines changed

3 files changed

+21
-42
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -575,10 +575,10 @@ def hybrid_prediction(
575575
if self.llm_attention_mask is not None:
576576
print('No LLM input for hybrid prediction but the model '
577577
'has been trained using an LLM model input.. '
578-
'Using only DCA for hybridprediction.. This can lead '
578+
'Using only DCA for hybrid prediction.. This can lead '
579579
'to unwanted prediction behavior if the hybrid model '
580580
'is trained including an LLM...')
581-
return self.beta1 * y_dca + self.beta2
581+
return self.beta1 * y_dca + self.beta2 * y_ridge
582582

583583
else:
584584
if self.llm_key == 'prosst':
@@ -615,13 +615,6 @@ def hybrid_prediction(
615615
#desc='Infering LoRA-tuned model',
616616
device=self.device).detach().cpu().numpy()
617617

618-
619-
#y_dca, y_ridge, y_llm, y_llm_lora = (
620-
# reduce_by_batch_modulo(y_dca, batch_size=self.batch_size),
621-
# reduce_by_batch_modulo(y_ridge, batch_size=self.batch_size),
622-
# reduce_by_batch_modulo(y_llm, batch_size=self.batch_size),
623-
# reduce_by_batch_modulo(y_llm_lora, batch_size=self.batch_size)
624-
#)
625618
return self.beta1 * y_dca + self.beta2 * y_ridge + self.beta3 * y_llm + self.beta4 * y_llm_lora
626619

627620
def split_performance(

scripts/ProteinGym_runs/results/dca_esm_and_hybrid_opt_results.csv

Lines changed: 0 additions & 27 deletions
This file was deleted.

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -505,15 +505,28 @@ def plot_csv_data(csv, plot_name):
505505

506506
compute_performances(
507507
mut_data=combined_mut_data,
508-
start_i=0,#start_i,
508+
start_i=start_i,
509509
already_tested_is=already_tested_is
510510
)
511511

512512

513513
with open(out_results_csv, 'r') as fh:
514-
with open(os.path.join(os.path.dirname(__file__), 'results/dca_esm_and_hybrid_opt_results_clean.csv'), 'w') as fh2:
515-
for line in fh:
516-
if not line.split(',')[1].startswith('OOM') and not line.split(',')[1].startswith('X'):
517-
fh2.write(line)
514+
lines = fh.readlines()
515+
clean_out_results_csv = os.path.join(
516+
os.path.dirname(__file__),
517+
'results/dca_esm_and_hybrid_opt_results_clean.csv'
518+
)
519+
with open(clean_out_results_csv, 'w') as fh2:
520+
header = lines[0]
521+
content = lines[1:]
522+
sort_keys = []
523+
for line in content:
524+
sort_keys.append(int(line.split(',')[0]))
525+
content_sorted, sort_keys_sorted = [l for l in zip(*sorted(
526+
zip(content, sort_keys), key=lambda x: x[1]))]
527+
fh2.write(header)
528+
for line in content_sorted:
529+
if not line.split(',')[1].startswith('OOM') and not line.split(',')[1].startswith('X'):
530+
fh2.write(line)
518531

519-
plot_csv_data(csv=os.path.join(os.path.dirname(__file__), 'results/dca_esm_and_hybrid_opt_results_clean.csv'), plot_name='mut_performance')
532+
plot_csv_data(csv=clean_out_results_csv, plot_name='mut_performance')

0 commit comments

Comments
 (0)