Skip to content

Commit 43ccba5

Browse files
committed
Update directed evolution to use the hybrid model
Variant values are now reported relative to the wild type (WT). However, negative values are not accepted yet.
1 parent bfa33ac commit 43ccba5

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,21 +1471,17 @@ def predict_directed_evolution(
14711471
y_pred = get_delta_e_statistical_model(xs, x_wt)
14721472
else: # model_type == 'Hybrid': Hybrid model input requires params
14731473
#from PLMC or GREMLIN model plus optional LLM input
1474-
print(variant, variant_sequence)
14751474
xs, variant, variant_sequence, *_ = plmc_or_gremlin_encoding(
14761475
variant, variant_sequence, None, encoder,
14771476
verbose=False, use_global_model=True
14781477
)
1479-
print(variant_sequence)
14801478
if not list(xs):
14811479
return 'skip'
14821480
if model.llm_model_input is None:
14831481
x_llm = None
14841482
else:
14851483
x_llm = llm_embedder(model.llm_model_input, variant_sequence)
14861484
try:
1487-
print(np.shape(xs), np.shape(x_llm), np.atleast_2d(x_llm))
1488-
#exit()
14891485
y_pred = model.hybrid_prediction(np.atleast_2d(xs), np.atleast_2d(x_llm))[0]
14901486
except ValueError as e:
14911487
raise e # TODO: Check sequences / mutations

pypef/utils/directed_evolution.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__( # Instance attributes
123123
self.negative = negative
124124
self.de_step_counter = 0 # DE steps
125125
self.traj_counter = 0 # Trajectory counter
126+
logger.info(f"Directed evolution acceptance \"temperature\": {self.temp}")
126127

127128
def mutate_sequence(
128129
self,
@@ -216,6 +217,7 @@ def in_silico_de(self):
216217
y_traj.append(self.y_wt)
217218
s_traj.append(self.s_wt)
218219
accepted = 0
220+
wt_prediction = None
219221
logger.info(f"Step 0: WT --> {self.y_wt:.3f}")
220222
for iteration in range(self.num_iterations): # num_iterations
221223
self.de_step_counter = iteration
@@ -248,20 +250,30 @@ def in_silico_de(self):
248250
)
249251

250252
else: # hybrid modeling and prediction
253+
if wt_prediction is None:
254+
while wt_prediction is None or wt_prediction == 'skip':
255+
wt_prediction = predict_directed_evolution(
256+
encoder=self.dca_encoder,
257+
variant=self.s_wt[int(new_variant[:-1]) - 1] + new_variant[:-1] +
258+
self.s_wt[int(new_variant[:-1]) - 1], # WT, e.g. F17F
259+
variant_sequence=self.s_wt,
260+
hybrid_model_data_pkl=self.model
261+
)
251262
predictions = predict_directed_evolution(
252263
encoder=self.dca_encoder,
253264
variant=self.s_wt[int(new_variant[:-1]) - 1] + new_variant,
254265
variant_sequence=new_sequence,
255266
hybrid_model_data_pkl=self.model
256267
)
268+
print(wt_prediction)
257269
if predictions != 'skip':
258270
logger.info(f"Step {self.de_step_counter + 1}: "
259-
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> {predictions[0][0]:.3f}")
271+
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> {predictions[0][0] - wt_prediction[0][0]:.3f}")
260272
else: # skip if variant cannot be encoded by DCA-based encoding technique
261273
logger.info(f"Step {self.de_step_counter + 1}: "
262274
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> {predictions}")
263275
continue
264-
new_y, new_var = predictions[0][0], predictions[0][1] # new_var == new_variant nonetheless
276+
new_y, new_var = predictions[0][0] - wt_prediction[0][0], predictions[0][1] # new_var == new_variant nonetheless
265277
# probability function for trial sequence
266278
# The lower the fitness (y) of the new variant, the higher are the chances to get excluded
267279
with warnings.catch_warnings(): # catching Overflow warning
@@ -275,10 +287,13 @@ def in_silico_de(self):
275287
p = min(1, boltz)
276288
rand_var = random.random() # random float between 0 and 1
277289
if rand_var < p: # Metropolis-Hastings update selection criterion, else do nothing (do not accept variant)
290+
logger.info(f'Accepted variant {new_var} [current evolutionary trajectory: {v_traj}]')
278291
v_traj.append(new_var) # update the variant naming trajectory
279292
y_traj.append(new_y) # update the fitness trajectory records
280293
s_traj.append(new_sequence) # update the sequence trajectory records
281294
accepted += 1
295+
else:
296+
logger.info(f'Rejected variant {new_var} [current evolutionary trajectory: {v_traj}]')
282297

283298
self.assert_trajectory_sequences(v_traj, s_traj)
284299

0 commit comments

Comments
 (0)