Skip to content

Commit 778e761

Browse files
committed
Update DirectedEvolution class
1 parent 6ebfcac commit 778e761

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,4 @@ datasets/AVGFP/model_saves/*
420420
datasets/AVGFP/Pickles/*
421421
datasets/AVGFP/DCA_Hybrid_Model_Performance_ESM1v_no_ML.png
422422
datasets/AVGFP/DCA_Hybrid_Model_Performance_ProSST_no_ML.png
423+
datasets/AVGFP/HYBRIDgremlinesm_DE_trajectories.png

pypef/hybrid/hybrid_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,7 @@ def performance_ls_ts(
12101210
)
12111211
model_name = f'HYBRID{model_type.lower()}{llm.lower()}'
12121212
y_test_pred = hybrid_model.hybrid_prediction(np.array(x_test), x_llm_test)
1213-
print(f'Hybrid performance: {spearmanr(y_test, y_test_pred)}')
1213+
print(f'Hybrid performance: {spearmanr(y_test, y_test_pred)[0]:.3f} N={len(y_test)}')
12141214
save_model_to_dict_pickle(hybrid_model, model_name)
12151215

12161216
elif (

pypef/utils/directed_evolution.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -240,22 +240,24 @@ def in_silico_de(self):
240240
if wt_prediction is None or wt_prediction == 'skip':
241241
wt_prediction = 'skip'
242242
while wt_prediction == 'skip':
243+
rand_pos = random.randint(0, len(self.s_wt))
244+
wt_mut = self.s_wt[rand_pos] + str(rand_pos) + self.s_wt[rand_pos]
243245
wt_prediction = predict( # AAidx, OneHot, or DCA-based pure ML prediction
244246
path=self.path,
245247
model=self.model,
246248
encoding=self.encoding,
247-
variants=np.atleast_1d(self.s_wt[int(new_variant[:-1]) - 1] + new_variant[:-1] +
248-
self.s_wt[int(new_variant[:-1]) - 1]),
249+
variants=np.atleast_1d(wt_mut),
249250
sequences=np.atleast_1d(self.s_wt),
250251
no_fft=self.no_fft,
251252
couplings_file=self.dca_encoder
252253
)
253-
logger.info(
254-
f"Step {self.de_step_counter}: "
255-
f"{self.s_wt[int(new_variant[:-1]) - 1] + new_variant[:-1] + self.s_wt[int(new_variant[:-1]) - 1]} --> "
256-
f"{wt_prediction[0][0]} WT relative fitness: {wt_prediction[0][0] - wt_prediction[0][0] + add_epsilon:.3f}"
257-
)
258-
y_traj[0] = wt_prediction[0][0]
254+
if self.de_step_counter == 0:
255+
logger.info(
256+
f"Step {self.de_step_counter}: "
257+
f"WT ({wt_mut}) --> {wt_prediction[0][0]:.3f} WT relative fitness: "
258+
f"{wt_prediction[0][0] - wt_prediction[0][0] + add_epsilon:.3f}"
259+
)
260+
y_traj[0] = wt_prediction[0][0]
259261
predictions = predict( # AAidx, OneHot, or DCA-based pure ML prediction
260262
path=self.path,
261263
model=self.model,
@@ -270,32 +272,33 @@ def in_silico_de(self):
270272
if wt_prediction is None or wt_prediction == 'skip':
271273
wt_prediction = 'skip'
272274
while wt_prediction == 'skip':
275+
rand_pos = random.randint(0, len(self.s_wt))
276+
wt_mut = self.s_wt[rand_pos] + str(rand_pos) + self.s_wt[rand_pos]
273277
wt_prediction = predict_directed_evolution(
274278
encoder=self.dca_encoder,
275279
variant=self.s_wt[int(new_variant[:-1]) - 1] + new_variant[:-1] +
276280
self.s_wt[int(new_variant[:-1]) - 1], # WT, e.g. F17F
277281
variant_sequence=self.s_wt,
278282
hybrid_model_data_pkl=self.model
279283
)
280-
logger.info(
281-
f"Step {self.de_step_counter}: "
282-
f"WT ({self.s_wt[int(new_variant[:-1]) - 1] + new_variant[:-1] + self.s_wt[int(new_variant[:-1]) - 1]}) --> "
283-
f"{wt_prediction[0][0]} WT relative fitness: {wt_prediction[0][0] - wt_prediction[0][0] + add_epsilon:.3f}"
284-
)
284+
if self.de_step_counter == 0:
285+
logger.info(
286+
f"Step {self.de_step_counter}: "
287+
f"WT ({wt_mut}) --> {wt_prediction[0][0]:.3f} WT relative fitness: "
288+
f"{wt_prediction[0][0] - wt_prediction[0][0] + add_epsilon:.3f}"
289+
)
290+
# add_epsilon = 0.01 * abs(wt_prediction[0][0]) # Adding 1% to prediction for hybrid modeling!
285291
y_traj[0] = wt_prediction[0][0] - wt_prediction[0][0]
286-
#add_epsilon = 0.01 * abs(wt_prediction[0][0]) # Adding 1% to prediction for hybrid modeling!
287-
y_traj[0] = wt_prediction[0][0] - wt_prediction[0][0]
288292
predictions = predict_directed_evolution(
289293
encoder=self.dca_encoder,
290294
variant=self.s_wt[int(new_variant[:-1]) - 1] + new_variant,
291295
variant_sequence=new_sequence,
292296
hybrid_model_data_pkl=self.model
293297
)
294-
print('PREDICTIONS:', predictions)
295298
if predictions != 'skip':
296299
logger.info(f"Step {self.de_step_counter + 1}: "
297300
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> "
298-
f"{predictions[0][0]} WT relative fitness: {predictions[0][0] - wt_prediction[0][0] + add_epsilon:.3f}")
301+
f"{predictions[0][0]:.3f} WT relative fitness: {predictions[0][0] - wt_prediction[0][0] + add_epsilon:.3f}")
299302
else: # skip if variant cannot be encoded by DCA-based encoding technique
300303
logger.info(f"Step {self.de_step_counter + 1}: "
301304
f"{self.s_wt[int(new_variant[:-1]) - 1]}{new_variant} --> {predictions}")
@@ -319,9 +322,9 @@ def in_silico_de(self):
319322
y_traj.append(new_y) # update the fitness trajectory records
320323
s_traj.append(new_sequence) # update the sequence trajectory records
321324
accepted += 1
322-
logger.info(f'Accepted variant {new_var} [current evolutionary trajectory: {v_traj}]')
325+
logger.info(f'Accepted variant {new_var} (current evolutionary trajectory: {v_traj})')
323326
else:
324-
logger.info(f'Rejected variant {new_var} [current evolutionary trajectory: {v_traj}]')
327+
logger.info(f'Rejected variant {new_var} (current evolutionary trajectory: {v_traj})')
325328

326329
self.assert_trajectory_sequences(v_traj, s_traj)
327330

0 commit comments

Comments
 (0)