|
39 | 39 | from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model |
40 | 40 | from pypef.plm.esm_lora_tune import esm_setup, get_esm_models |
41 | 41 | from pypef.plm.prosst_lora_tune import get_prosst_models, prosst_setup |
42 | | -from pypef.plm.inference import llm_embedder, inference |
| 42 | +from pypef.plm.inference import llm_tokenizer, inference |
43 | 43 | from pypef.plm.utils import get_batches |
44 | 44 |
|
45 | 45 | # sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and |
@@ -1030,11 +1030,11 @@ def performance_ls_ts( |
1030 | 1030 | if llm is not None: |
1031 | 1031 | if llm.lower().startswith('esm'): |
1032 | 1032 | llm_dict = esm_setup(train_sequences) |
1033 | | - x_llm_test = llm_embedder(llm_dict, test_sequences) |
| 1033 | + x_llm_test = llm_tokenizer(llm_dict, test_sequences) |
1034 | 1034 | elif llm.lower() == 'prosst': |
1035 | 1035 | llm_dict = prosst_setup( |
1036 | 1036 | wt_seq, pdb_file, sequences=train_sequences) |
1037 | | - x_llm_test = llm_embedder(llm_dict, test_sequences) |
| 1037 | + x_llm_test = llm_tokenizer(llm_dict, test_sequences) |
1038 | 1038 | else: |
1039 | 1039 | llm_dict = None |
1040 | 1040 | x_llm_test = None |
@@ -1080,7 +1080,7 @@ def performance_ls_ts( |
1080 | 1080 | ) |
1081 | 1081 | if model.llm_model_input is not None: |
1082 | 1082 | logger.info(f"Found hybrid model with LLM {list(model.llm_model_input.keys())[0]}...") |
1083 | | - x_llm_test = llm_embedder(model.llm_model_input, test_sequences) |
| 1083 | + x_llm_test = llm_tokenizer(model.llm_model_input, test_sequences) |
1084 | 1084 | y_test_pred = model.hybrid_prediction(x_test, x_llm_test) |
1085 | 1085 | else: |
1086 | 1086 | y_test_pred = model.hybrid_prediction(x_test) |
@@ -1236,7 +1236,7 @@ def predict_ps( |
1236 | 1236 | ys_pred = model.hybrid_prediction(x_test) |
1237 | 1237 | else: |
1238 | 1238 | sequences = [str(seq) for seq in test_sequences] |
1239 | | - x_llm_test = llm_embedder(model.llm_model_input, sequences) |
| 1239 | + x_llm_test = llm_tokenizer(model.llm_model_input, sequences) |
1240 | 1240 | ys_pred = model.hybrid_prediction(np.asarray(x_test), np.asarray(x_llm_test)) |
1241 | 1241 | for k, y in enumerate(ys_pred): |
1242 | 1242 | all_y_v_pred.append((ys_pred[k], variants[k])) |
@@ -1283,7 +1283,7 @@ def predict_ps( |
1283 | 1283 | ys_pred = model.hybrid_prediction(xs) |
1284 | 1284 | else: |
1285 | 1285 | sequences = [str(seq) for seq in sequences] |
1286 | | - xs_llm = llm_embedder(model.llm_model_input, sequences) |
| 1286 | + xs_llm = llm_tokenizer(model.llm_model_input, sequences) |
1287 | 1287 | ys_pred = model.hybrid_prediction(np.asarray(xs), np.asarray(xs_llm)) |
1288 | 1288 | assert len(xs) == len(variants) == len(ys_pred) |
1289 | 1289 | y_v_pred = zip(ys_pred, variants) |
@@ -1343,7 +1343,7 @@ def predict_directed_evolution( |
1343 | 1343 | if model.llm_model_input is None: |
1344 | 1344 | y_pred = model.hybrid_prediction(xs) |
1345 | 1345 | else: |
1346 | | - x_llm = llm_embedder(model.llm_model_input, |
| 1346 | + x_llm = llm_tokenizer(model.llm_model_input, |
1347 | 1347 | variant_sequence, verbose=False) |
1348 | 1348 |
|
1349 | 1349 | y_pred = model.hybrid_prediction( |
|
0 commit comments