Skip to content

Commit 7f4bfff

Browse files
committed
Making hybrid DCA+LLM DIrectedEvolution much faster
by adding global hybrid model variables. Would be nicer to use a class instead of function globals or decorators?
1 parent 168d9fc commit 7f4bfff

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,7 +1417,6 @@ def predict_ps( # also predicting "pmult" dict directories
14171417

14181418
elif prediction_set is not None: # Predicting single FASTA file sequences
14191419
sequences, variants, _ = get_sequences_from_file(prediction_set)
1420-
print(len(sequences), len(variants))
14211420
# NaNs are already being removed by the called function
14221421
if model_type != 'Hybrid': # statistical DCA model
14231422
xs, variants, _, _, x_wt, *_ = plmc_or_gremlin_encoding(
@@ -1426,12 +1425,10 @@ def predict_ps( # also predicting "pmult" dict directories
14261425
)
14271426
ys_pred = get_delta_e_statistical_model(xs, x_wt)
14281427
else: # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
1429-
print(len(variants))
14301428
xs, variants, sequences, *_ = plmc_or_gremlin_encoding(
14311429
variants, sequences, None, params_file,
14321430
threads=threads, verbose=True, substitution_sep=separator
14331431
)
1434-
print('xs len', len(xs), len(variants))
14351432
if model.llm_key is None:
14361433
ys_pred = model.hybrid_prediction(xs)
14371434
else:
@@ -1448,11 +1445,15 @@ def predict_ps( # also predicting "pmult" dict directories
14481445
)
14491446

14501447

1448+
global_hybrid_model = None
1449+
global_hybrid_model_type = None
1450+
1451+
14511452
def predict_directed_evolution(
14521453
encoder: str,
14531454
variant: str,
14541455
variant_sequence: str,
1455-
hybrid_model_data_pkl: str
1456+
hybrid_model_data_pkl: None | str
14561457
) -> Union[str, list]:
14571458
"""
14581459
Perform directed in silico evolution and predict the fitness of a
@@ -1462,8 +1463,14 @@ def predict_directed_evolution(
14621463
cannot be encoded (based on the PLMC params file), returns 'skip'. Else,
14631464
returning the predicted fitness value and the variant name.
14641465
"""
1466+
global global_hybrid_model, global_hybrid_model_type
14651467
if hybrid_model_data_pkl is not None:
1466-
model, model_type = get_model_and_type(hybrid_model_data_pkl)
1468+
if global_hybrid_model is None:
1469+
global_hybrid_model, global_hybrid_model_type = get_model_and_type(
1470+
hybrid_model_data_pkl)
1471+
model, model_type = global_hybrid_model, global_hybrid_model_type
1472+
else:
1473+
model, model_type = global_hybrid_model, global_hybrid_model_type
14671474
else:
14681475
model_type = 'StatisticalModel' # any name != 'Hybrid'
14691476

@@ -1475,7 +1482,7 @@ def predict_directed_evolution(
14751482
return 'skip'
14761483
y_pred = get_delta_e_statistical_model(xs, x_wt)
14771484
else: # model_type == 'Hybrid': Hybrid model input requires params
1478-
#from PLMC or GREMLIN model plus optional LLM input
1485+
# from PLMC or GREMLIN model plus optional LLM input
14791486
xs, variant, variant_sequence, *_ = plmc_or_gremlin_encoding(
14801487
variant, variant_sequence, None, encoder,
14811488
verbose=False, use_global_model=True
@@ -1485,9 +1492,13 @@ def predict_directed_evolution(
14851492
if model.llm_model_input is None:
14861493
x_llm = None
14871494
else:
1488-
x_llm = llm_embedder(model.llm_model_input, variant_sequence, verbose=False)
1495+
x_llm = llm_embedder(model.llm_model_input,
1496+
variant_sequence, verbose=False)
14891497
try:
1490-
y_pred = model.hybrid_prediction(np.atleast_2d(xs), np.atleast_2d(x_llm), verbose=False)[0]
1498+
y_pred = model.hybrid_prediction(
1499+
np.atleast_2d(xs),
1500+
np.atleast_2d(x_llm), verbose=False
1501+
)[0]
14911502
except ValueError as e:
14921503
raise e # TODO: Check sequences / mutations
14931504
# raise SystemError(

pypef/utils/directed_evolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ def in_silico_de(self):
299299
if predictions != 'skip':
300300
logger.info(f"Step {self.de_step_counter + 1}: "
301301
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> "
302-
f"{predictions[0][0]:.3f} WT relative fitness: {predictions[0][0] - wt_prediction[0][0] + add_epsilon:.3f}")
302+
f"{predictions[0][0]:.3f} WT relative fitness: "
303+
f"{predictions[0][0] - wt_prediction[0][0] + add_epsilon:.3f}")
303304
else: # skip if variant cannot be encoded by DCA-based encoding technique
304305
logger.info(f"Step {self.de_step_counter + 1}: "
305306
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> {predictions}")

0 commit comments

Comments
 (0)