4242
4343JUST_PLOT_RESULTS = False
4444
45-
4645def compute_performances (mut_data , mut_sep = ':' , start_i : int = 0 , already_tested_is : list = []):
4746 # Get cpu, gpu or mps device for training.
4847 device = get_device ()
@@ -272,7 +271,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
272271 for m in ['DCA' , 'ESM1v' , 'ProSST' , 'DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ]:
273272 temp_results [category ][f'Split { i_split } ' ].update ({m : np .nan })
274273 continue
275- #get_vram()
276274
277275 y_test_pred_dca = get_delta_e_statistical_model (x_dca_test , x_wt )
278276 temp_results [category ][f'Split { i_split } ' ].update ({'DCA' : spearmanr (y_test , y_test_pred_dca )[0 ]})
@@ -286,7 +284,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
286284
287285 for i_m , method in enumerate ([None , llm_dict_esm , llm_dict_prosst ]):
288286 m_str = ['DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ][i_m ]
289- #print('\n~~~ ' + m_str + ' ~~~')
290287 try :
291288 hm = DCALLMHybridModel (
292289 x_train_dca = np .array (x_dca_train ),
@@ -317,7 +314,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
317314 gc .collect ()
318315
319316 dt = time .time () - start_time
320- print (json .dumps (temp_results , indent = 4 ))
321317
322318 with open (out_results_csv , 'a' ) as fh :
323319 fh .write (
@@ -354,195 +350,34 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
354350
355351
356352def plot_csv_data (csv , plot_name ):
357- train_test_size_texts = []
358- df = pd .read_csv (csv , sep = ',' )
359- tested_dsets = df ['No.' ]
360- dset_dca_perfs = df ['Untrained_Performance_DCA' ]
361- dset_esm_perfs = df ['Untrained_Performance_ESM1v' ]
362- dset_prosst_perfs = df ['Untrained_Performance_ProSST' ]
363- dset_hybrid_perfs_dca_100 = df ['Hybrid_DCA_Trained_Performance_100' ]
364- dset_hybrid_perfs_dca_200 = df ['Hybrid_DCA_Trained_Performance_200' ]
365- dset_hybrid_perfs_dca_1000 = df ['Hybrid_DCA_Trained_Performance_1000' ]
366- dset_hybrid_perfs_dca_esm_100 = df ['Hybrid_DCA_ESM1v_Trained_Performance_100' ]
367- dset_hybrid_perfs_dca_esm_200 = df ['Hybrid_DCA_ESM1v_Trained_Performance_200' ]
368- dset_hybrid_perfs_dca_esm_1000 = df ['Hybrid_DCA_ESM1v_Trained_Performance_1000' ]
369- dset_hybrid_perfs_dca_prosst_100 = df ['Hybrid_DCA_ProSST_Trained_Performance_100' ]
370- dset_hybrid_perfs_dca_prosst_200 = df ['Hybrid_DCA_ProSST_Trained_Performance_200' ]
371- dset_hybrid_perfs_dca_prosst_1000 = df ['Hybrid_DCA_ProSST_Trained_Performance_1000' ]
372-
373- blue_colors = mpl .colormaps ['Blues' ](np .linspace (0.3 , 0.9 , 4 ))
374- red_colors = mpl .colormaps ['Reds' ](np .linspace (0.3 , 0.9 , 4 ))
375- green_colors = mpl .colormaps ['Greens' ](np .linspace (0.3 , 0.9 , 4 ))
376-
377- plt .figure (figsize = (80 , 12 ))
378- plt .plot (range (len (tested_dsets )), dset_dca_perfs , 'o--' , markersize = 8 ,
379- color = blue_colors [0 ], label = 'DCA (0)' )
380- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_dca_perfs )),
381- color = blue_colors [0 ], linestyle = '--' )
382- for i , (p , n_test ) in enumerate (zip (
383- dset_dca_perfs , df ['N_Y_test' ].astype ('Int64' ).to_list ())):
384- plt .text (i , 0.975 , i , color = 'black' , size = 2 )
385- plt .text (i , 0.980 , f'0' + r'$\rightarrow$' + f'{ n_test } ' , color = 'black' , size = 2 )
386- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_dca_perfs ),
387- f'{ np .nanmean (dset_dca_perfs ):.2f} ' , color = blue_colors [0 ]))
388-
389- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_100 ,
390- 'o--' , markersize = 8 , color = blue_colors [1 ], label = 'Hybrid (100)' )
391- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_100 )),
392- color = blue_colors [1 ], linestyle = '--' )
393- for i , (p , n_test ) in enumerate (zip (dset_hybrid_perfs_dca_100 , df ['N_Y_test_100' ].astype ('Int64' ).to_list ())):
394- plt .text (i , 0.985 , f'100' + r'$\rightarrow$' + f'{ n_test } ' , color = 'black' , size = 2 )
395- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_100 ),
396- f'{ np .nanmean (dset_hybrid_perfs_dca_100 ):.2f} ' , color = blue_colors [1 ]))
397-
398- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_200 ,
399- 'o--' , markersize = 8 , color = blue_colors [2 ], label = 'Hybrid (200)' )
400- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_200 )),
401- color = blue_colors [2 ], linestyle = '--' )
402- for i , (p , n_test ) in enumerate (zip (
403- dset_hybrid_perfs_dca_200 , df ['N_Y_test_200' ].astype ('Int64' ).to_list ())):
404- plt .text (i , 0.990 , f'200' + r'$\rightarrow$' + f'{ n_test } ' , color = 'black' , size = 2 )
405- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_200 ),
406- f'{ np .nanmean (dset_hybrid_perfs_dca_200 ):.2f} ' , color = blue_colors [2 ]))
407-
408- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_1000 ,
409- 'o--' , markersize = 8 , color = blue_colors [3 ], label = 'Hybrid (1000)' )
410- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_1000 )),
411- color = blue_colors [3 ], linestyle = '--' )
412- for i , (p , n_test ) in enumerate (zip (
413- dset_hybrid_perfs_dca_1000 , df ['N_Y_test_1000' ].astype ('Int64' ).to_list ())):
414- plt .text (i , 0.995 , f'1000' + r'$\rightarrow$' + f'{ n_test } ' , color = 'black' , size = 2 )
415- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_1000 ),
416- f'{ np .nanmean (dset_hybrid_perfs_dca_1000 ):.2f} ' , color = blue_colors [3 ]))
417-
418-
419- plt .plot (range (len (tested_dsets )), dset_esm_perfs ,
420- 'o--' , markersize = 8 , color = green_colors [0 ], label = 'ESM (0)' )
421- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_esm_perfs )),
422- color = green_colors [0 ], linestyle = '--' )
423- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_esm_perfs ),
424- f'{ np .nanmean (dset_esm_perfs ):.2f} ' , color = green_colors [0 ]))
425-
426- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_esm_100 ,
427- 'o--' , markersize = 8 , color = green_colors [1 ], label = 'Hybrid (100)' )
428- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_esm_100 )),
429- color = green_colors [1 ], linestyle = '--' )
430- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_esm_100 ),
431- f'{ np .nanmean (dset_hybrid_perfs_dca_esm_100 ):.2f} ' , color = green_colors [1 ]))
432-
433- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_esm_200 ,
434- 'o--' , markersize = 8 , color = green_colors [2 ], label = 'Hybrid (200)' )
435- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_esm_200 )),
436- color = green_colors [2 ], linestyle = '--' )
437- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_esm_200 ),
438- f'{ np .nanmean (dset_hybrid_perfs_dca_esm_200 ):.2f} ' , color = green_colors [2 ]))
439-
440- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_esm_1000 ,
441- 'o--' , markersize = 8 , color = green_colors [3 ], label = 'Hybrid (1000)' )
442- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_esm_1000 )),
443- color = green_colors [3 ], linestyle = '--' )
444- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_esm_1000 ),
445- f'{ np .nanmean (dset_hybrid_perfs_dca_esm_1000 ):.2f} ' , color = green_colors [3 ]))
446-
447-
448- plt .plot (range (len (tested_dsets )), dset_prosst_perfs ,
449- 'o--' , markersize = 8 , color = red_colors [0 ], label = 'ProSST (0)' )
450- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_prosst_perfs )),
451- color = red_colors [0 ], linestyle = '--' )
452- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_prosst_perfs ),
453- f'{ np .nanmean (dset_prosst_perfs ):.2f} ' , color = red_colors [0 ]))
454-
455- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_prosst_100 ,
456- 'o--' , markersize = 8 , color = red_colors [1 ], label = 'Hybrid (100)' )
457- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_prosst_100 )),
458- color = red_colors [1 ], linestyle = '--' )
459- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_prosst_100 ),
460- f'{ np .nanmean (dset_hybrid_perfs_dca_prosst_100 ):.2f} ' , color = red_colors [1 ]))
461-
462- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_prosst_200 ,
463- 'o--' , markersize = 8 , color = red_colors [2 ], label = 'Hybrid (200)' )
464- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_prosst_200 )),
465- color = red_colors [2 ], linestyle = '--' )
466- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_prosst_200 ),
467- f'{ np .nanmean (dset_hybrid_perfs_dca_prosst_200 ):.2f} ' , color = red_colors [2 ]))
468-
469- plt .plot (range (len (tested_dsets )), dset_hybrid_perfs_dca_prosst_1000 ,
470- 'o--' , markersize = 8 , color = red_colors [3 ], label = 'Hybrid (1000)' )
471- plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_hybrid_perfs_dca_prosst_1000 )),
472- color = red_colors [3 ], linestyle = '--' )
473- train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_prosst_1000 ),
474- f'{ np .nanmean (dset_hybrid_perfs_dca_prosst_1000 ):.2f} ' , color = red_colors [3 ]))
475-
476-
477- plt .grid (zorder = - 1 )
478- plt .xticks (
479- range (len (tested_dsets )),
480- ['(' + str (n ) + ') ' + name for (n , name ) in zip (tested_dsets , df ['Dataset' ].to_list ())],
481- rotation = 45 , ha = 'right'
482- )
483- plt .margins (0.01 )
484- plt .legend ()
485- plt .tight_layout ()
486- plt .ylim (0.0 , 1.0 )
487- plt .xlabel ('Tested dataset' )
488- plt .ylabel (r'Spearman $\rho$' )
489- adjust_text (train_test_size_texts , expand = (1.2 , 2 ))
490- plt .savefig (os .path .join (os .path .dirname (__file__ ), f'{ plot_name } .png' ), dpi = 300 )
491- print ('Saved file as ' + os .path .join (os .path .dirname (__file__ ), f'{ plot_name } .png' ) + '.' )
492-
493- plt .clf ()
494353 plt .figure (figsize = (24 , 12 ))
495354 sns .set_style ("whitegrid" )
496- df_ = df [[
497- 'Untrained_Performance_DCA' ,
498- 'Hybrid_DCA_Trained_Performance_100' ,
499- 'Hybrid_DCA_Trained_Performance_200' ,
500- 'Hybrid_DCA_Trained_Performance_1000' ,
501- 'Untrained_Performance_ESM1v' ,
502- 'Hybrid_DCA_ESM1v_Trained_Performance_100' ,
503- 'Hybrid_DCA_ESM1v_Trained_Performance_200' ,
504- 'Hybrid_DCA_ESM1v_Trained_Performance_1000' ,
505- 'Untrained_Performance_ProSST' ,
506- 'Hybrid_DCA_ProSST_Trained_Performance_100' ,
507- 'Hybrid_DCA_ProSST_Trained_Performance_200' ,
508- 'Hybrid_DCA_ProSST_Trained_Performance_1000' ,
509- ]]
510- print (df_ )
355+ df = pd .read_csv (csv , sep = ',' )
356+ df_mean = pd .DataFrame ()
357+ print (df )
358+ print (df .columns )
359+ for method in ['DCA_hybrid' , 'DCA+ESM1v_hybrid' , 'DCA+ProSST_hybrid' ]:
360+ for split_technique in ['Random' , 'Modulo' , 'Continuous' ]:
361+ performances = []
362+ for split in range (1 , 6 ):
363+ performances .append (df [f'{ split_technique } _Split_{ split } _{ method } ' ].to_list ())
364+ df_mean [f'{ method } _{ split_technique } _mean' ] = np .mean (performances , axis = 0 )
511365 plot = sns .violinplot (
512- df_ , saturation = 0.4 ,
513- palette = [blue_colors [0 ], blue_colors [1 ], blue_colors [2 ], blue_colors [3 ],
514- green_colors [0 ],green_colors [1 ], green_colors [2 ], green_colors [3 ],
515- red_colors [0 ], red_colors [1 ], red_colors [2 ], red_colors [3 ]]
516- )
517- plt .ylabel (r'Spearmanr $\rho$' )
518- sns .swarmplot (df_ , color = 'black' )
519- dset_perfs = [
520- dset_dca_perfs ,
521- dset_hybrid_perfs_dca_100 ,
522- dset_hybrid_perfs_dca_200 ,
523- dset_hybrid_perfs_dca_1000 ,
524- dset_esm_perfs ,
525- dset_hybrid_perfs_dca_esm_100 ,
526- dset_hybrid_perfs_dca_esm_200 ,
527- dset_hybrid_perfs_dca_esm_1000 ,
528- dset_prosst_perfs ,
529- dset_hybrid_perfs_dca_prosst_100 ,
530- dset_hybrid_perfs_dca_prosst_200 ,
531- dset_hybrid_perfs_dca_prosst_1000
532- ]
533- for n in range (0 , len (dset_perfs )):
366+ df_mean , saturation = 0.4
367+ )
368+ sns .swarmplot (df_mean )
369+ for n in range (0 , df_mean .shape [1 ]):
534370 plt .text (
535371 n + 0.15 , - 0.075 ,
536- r'$\overline{\rho}=$' + f'{ np .nanmean (dset_perfs [ n ]):.3f} \n '
537- + r'$N_\mathrm{Datasets}=$' + f'{ np .count_nonzero (~ np .isnan (np .array (dset_perfs )[ n ] ))} '
372+ r'$\overline{\rho}=$' + f'{ np .nanmean (df_mean . iloc [:, n ]):.3f} \n '
373+ + r'$N_\mathrm{Datasets}=$' + f'{ np .count_nonzero (~ np .isnan (np .array (df_mean . iloc [:, n ]) ))} '
538374 )
539375 plot .set_xticks (range (len (plot .get_xticklabels ())))
540376 plot .set_xticklabels (plot .get_xticklabels (), rotation = 45 , horizontalalignment = 'right' )
541377 plt .ylim (- 0.09 , 1.09 )
542378 plt .margins (0.05 )
543379 plt .tight_layout ()
544- plt .savefig (os .path .join (os .path .dirname (__file__ ), f'{ plot_name } _violin.png' ), dpi = 300 )
545- print ('Saved file as ' + os .path .join (os .path .dirname (__file__ ), f'{ plot_name } _violin.png' ) + '.' )
380+ plt .show ()
546381
547382
548383if __name__ == '__main__' :
0 commit comments