Skip to content

Commit 7c0243f

Browse files
committed
Add ESM tests
1 parent: 21e605f; commit: 7c0243f

File tree

11 files changed: +9922 / -43 lines changed

11 files changed: +9922 / -43 lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
sleep 3
5151
pip install .[gui]
5252
echo $(which pypef)
53-
python -m pytest ./tests/ -v -m "not main_script_specific" --log-cli-level=INFO
53+
python -m pytest ./tests/ -v -m "not (main_script_specific or requires_gpu)" --log-cli-level=INFO
5454
5555
windows:
5656
name: windows
@@ -88,5 +88,5 @@ jobs:
8888
$env:PYTHONPATH = ""
8989
pip install .[gui]
9090
echo (Get-Command pypef).Source
91-
python -m pytest .\tests -v -m "not main_script_specific" --log-cli-level=INFO
91+
python -m pytest .\tests -v -m "not (main_script_specific or requires_gpu)" --log-cli-level=INFO
9292

datasets/BLAT_ECOLX/BLAT_ECOLX.pdb

Lines changed: 4439 additions & 0 deletions
Large diffs are not rendered by default.

datasets/BLAT_ECOLX/BLAT_ECOLX_Stiffler_2015.csv

Lines changed: 4997 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
>BLAT_ECOLX_Stiffler
2+
MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW

pypef/gui/qt_window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from PySide6.QtWidgets import (
1414
QApplication, QPushButton, QTextEdit, QVBoxLayout, QWidget,
1515
QGridLayout, QLabel, QPlainTextEdit, QSlider, QComboBox,
16-
QFileDialog, QProgressBar, QSizePolicy
16+
QFileDialog, QProgressBar
1717
)
1818

1919
from pypef import __version__

pypef/hybrid/hybrid_model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
4040
from pypef.plm.esm_lora_tune import esm_setup, get_esm_models
4141
from pypef.plm.prosst_lora_tune import get_prosst_models, prosst_setup
42-
from pypef.plm.inference import llm_embedder, inference
42+
from pypef.plm.inference import llm_tokenizer, inference
4343
from pypef.plm.utils import get_batches
4444

4545
# sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and
@@ -1030,11 +1030,11 @@ def performance_ls_ts(
10301030
if llm is not None:
10311031
if llm.lower().startswith('esm'):
10321032
llm_dict = esm_setup(train_sequences)
1033-
x_llm_test = llm_embedder(llm_dict, test_sequences)
1033+
x_llm_test = llm_tokenizer(llm_dict, test_sequences)
10341034
elif llm.lower() == 'prosst':
10351035
llm_dict = prosst_setup(
10361036
wt_seq, pdb_file, sequences=train_sequences)
1037-
x_llm_test = llm_embedder(llm_dict, test_sequences)
1037+
x_llm_test = llm_tokenizer(llm_dict, test_sequences)
10381038
else:
10391039
llm_dict = None
10401040
x_llm_test = None
@@ -1080,7 +1080,7 @@ def performance_ls_ts(
10801080
)
10811081
if model.llm_model_input is not None:
10821082
logger.info(f"Found hybrid model with LLM {list(model.llm_model_input.keys())[0]}...")
1083-
x_llm_test = llm_embedder(model.llm_model_input, test_sequences)
1083+
x_llm_test = llm_tokenizer(model.llm_model_input, test_sequences)
10841084
y_test_pred = model.hybrid_prediction(x_test, x_llm_test)
10851085
else:
10861086
y_test_pred = model.hybrid_prediction(x_test)
@@ -1236,7 +1236,7 @@ def predict_ps(
12361236
ys_pred = model.hybrid_prediction(x_test)
12371237
else:
12381238
sequences = [str(seq) for seq in test_sequences]
1239-
x_llm_test = llm_embedder(model.llm_model_input, sequences)
1239+
x_llm_test = llm_tokenizer(model.llm_model_input, sequences)
12401240
ys_pred = model.hybrid_prediction(np.asarray(x_test), np.asarray(x_llm_test))
12411241
for k, y in enumerate(ys_pred):
12421242
all_y_v_pred.append((ys_pred[k], variants[k]))
@@ -1283,7 +1283,7 @@ def predict_ps(
12831283
ys_pred = model.hybrid_prediction(xs)
12841284
else:
12851285
sequences = [str(seq) for seq in sequences]
1286-
xs_llm = llm_embedder(model.llm_model_input, sequences)
1286+
xs_llm = llm_tokenizer(model.llm_model_input, sequences)
12871287
ys_pred = model.hybrid_prediction(np.asarray(xs), np.asarray(xs_llm))
12881288
assert len(xs) == len(variants) == len(ys_pred)
12891289
y_v_pred = zip(ys_pred, variants)
@@ -1343,7 +1343,7 @@ def predict_directed_evolution(
13431343
if model.llm_model_input is None:
13441344
y_pred = model.hybrid_prediction(xs)
13451345
else:
1346-
x_llm = llm_embedder(model.llm_model_input,
1346+
x_llm = llm_tokenizer(model.llm_model_input,
13471347
variant_sequence, verbose=False)
13481348

13491349
y_pred = model.hybrid_prediction(

0 commit comments

Comments
 (0)