diff --git a/examples/eval_scireasoner.py b/examples/eval_scireasoner.py new file mode 100644 index 000000000..d493a581f --- /dev/null +++ b/examples/eval_scireasoner.py @@ -0,0 +1,157 @@ +from mmengine.config import read_base + +with read_base(): + # scireasoner + from opencompass.configs.datasets.SciReasoner.scireasoner_gen import scireasoner_datasets_full, scireasoner_datasets_mini + +# # full set summarizer +# summarizer = dict( +# dataset_abbrs=[ +# ['SciReasoner-bio_instruction-antibody_antigen', 'MCC'], ['SciReasoner-bio_instruction-rna_protein_interaction', 'MCC'], ['SciReasoner-bio_instruction-emp', 'MCC'], +# ['SciReasoner-bio_instruction-enhancer_activity', 'PCC'], ['SciReasoner-bio_instruction-tf_m', 'MCC'], ['SciReasoner-bio_instruction-Isoform', 'R2'], ['SciReasoner-bio_instruction-Modification', 'AUC'], +# ['SciReasoner-bio_instruction-MeanRibosomeLoading', 'R2'], ['SciReasoner-bio_instruction-ProgrammableRNASwitches', 'R2'], ['SciReasoner-bio_instruction-CRISPROnTarget', 'spearman'], +# ['SciReasoner-bio_instruction-promoter_enhancer_interaction', 'MCC'], ['SciReasoner-bio_instruction-sirnaEfficiency', 'mixed_score'], ['SciReasoner-bio_instruction-cpd', 'MCC'], +# ['SciReasoner-bio_instruction-pd', 'MCC'], ['SciReasoner-bio_instruction-tf_h', 'MCC'], +# ['SciReasoner-Gue_cpd-prom_core_all', 'matthews_correlation_all'], +# ['SciReasoner-Gue_cpd-prom_core_notata', 'matthews_correlation_all'], +# ['SciReasoner-Gue_cpd-prom_core_tata', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_all', 'matthews_correlation_all'], +# ['SciReasoner-Gue_pd-prom_300_notata', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_tata', 'matthews_correlation_all'], +# ['SciReasoner-Gue_tf-h-0', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-1', 'matthews_correlation_all'], +# ['SciReasoner-Gue_tf-h-2', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-3', 'matthews_correlation_all'], +# ['SciReasoner-Gue_tf-h-4', 'matthews_correlation_all'], ['SciReasoner-smol_forward_synthesis', 'top1_exact_match'], +# ['SciReasoner-smol_retrosynthesis', 'top1_exact_match'], ['SciReasoner-smol_molecule_captioning', 'meteor_score'], +# ['SciReasoner-smol_molecule_generation', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-i2f', 'top1_ele_match'], +# ['SciReasoner-smol_name_conversion-i2s', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-s2f', 'top1_ele_match'], +# ['SciReasoner-smol_name_conversion-s2i', 'top1_split_match'], ['SciReasoner-smol_property_prediction-esol', 'RMSE'], +# ['SciReasoner-smol_property_prediction-lipo', 'RMSE'], ['SciReasoner-smol_property_prediction-bbbp', 'accuracy'], +# ['SciReasoner-smol_property_prediction-clintox', 'accuracy'], ['SciReasoner-smol_property_prediction-hiv', 'accuracy'], +# ['SciReasoner-smol_property_prediction-sider', 'accuracy'], ['SciReasoner-retrosynthesis_USPTO_50K', 'Top-1 Accuracy'], +# ['SciReasoner-LLM4Mat_MP_IsStable', 'AUC'], ['SciReasoner-LLM4Mat_MP_IsGapDirect', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_IsDirect', 'AUC'], +# ['SciReasoner-LLM4Mat_SNUMAT_IsDirect_HSE', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_SOC', 'AUC'], ['SciReasoner-LLM4Mat_MP_FEPA', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_MP_Bandgap', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_EPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Ehull', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Efermi', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_MP_Density', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_DensityAtomic', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Volume', 'MAD/MAE'], +# 
['SciReasoner-LLM4Mat_JARVISDFT_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_OPT', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_TotEn', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_Ehull', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_MBJ', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Kv', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_Gv', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_SLME', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Spillage', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_Epsx_OPT', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Dielectric_DFPT', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_dij', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_eij', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_MaxEFG', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_ExfEn', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_AvgMe', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_nSeebeck', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_nPF', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_pSeebeck', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISDFT_pPF', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA_Optical', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE_Optical', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_GNoME_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_DEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Bandgap', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_GNoME_TotEn', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Volume', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Density', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_hMOF_MaxCO2', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_MinCO2', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_LCD', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_hMOF_PLD', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_VoidFraction', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_SA_m2g', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_hMOF_SA_m2cm3', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_EPA', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_Cantor_HEA_Ehull', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_VPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_TotEn', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_QMOF_Bandgap', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_LCD', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_PLD', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISQETB_EPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_IndirBandgap', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_JARVISQETB_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_TotEn', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OQMD_Bandgap', 'MAD/MAE'], +# ['SciReasoner-LLM4Mat_OQMD_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OMDB_Bandgap', 'MAD/MAE'], +# ['SciReasoner-composition_to_material_generation', 'smact_validity_ratio_in_all_%'], +# ['SciReasoner-bulk_modulus_to_material_generation', 'smact_validity_ratio_in_all_%'], +# ['SciReasoner-mol_instruction_chemical_disease_interaction_extraction', 'f1'], +# ['SciReasoner-mol_instruction_chemical_entity_recognition', 'f1'], +# ['SciReasoner-mol_instruction_chemical_protein_interaction_extraction', 'f1'], +# ['SciReasoner-mol_instruction_multi_choice_question', 'accuracy'], +# ['SciReasoner-mol_instruction_open_question', 'bert_score'], +# ['SciReasoner-mol_instruction_true_or_false_question', 'accuracy'], +# ['SciReasoner-mol_instruction_property_prediction_str', 'mae'], +# ['SciReasoner-mol_instruction_description_guided_molecule_design', 'exact_match_score'], +# 
['SciReasoner-mol_instruction_forward_reaction_prediction', 'exact_match_score'], +# ['SciReasoner-mol_instruction_retrosynthesis', 'exact_match_score'], +# ['SciReasoner-mol_instruction_reagent_prediction', 'exact_match_score'], +# ['SciReasoner-mol_instruction_molecular_description_generation', 'rougeL'], +# ['SciReasoner-mol_instruction_catalytic_activity', 'rougeL'], ['SciReasoner-mol_instruction_domain_motif', 'rougeL'], +# ['SciReasoner-mol_instruction_general_function', 'rougeL'], ['SciReasoner-mol_instruction_protein_function', 'rougeL'], +# ['SciReasoner-mol_instruction_protein_design', 'Max SW score'], ['SciReasoner-Opi_EC_number_CLEAN_EC_number_new', 'Accuracy'], +# ['SciReasoner-Opi_EC_number_CLEAN_EC_number_price', 'Accuracy'], ['SciReasoner-Opi_Fold_type_fold_type', 'Accuracy'], +# ['SciReasoner-Opi_Function_CASPSimilarSeq_function', 'ROUGE-L'], ['SciReasoner-Opi_Function_IDFilterSeq_function', 'ROUGE-L'], +# ['SciReasoner-Opi_Function_UniProtSeq_function', 'ROUGE-L'], ['SciReasoner-Opi_gName2Cancer_gene_name_to_cancer', 'F1 Score'], +# ['SciReasoner-Opi_GO_CASPSimilarSeq_go', 'F1 Score'], ['SciReasoner-Opi_GO_IDFilterSeq_go', 'F1 Score'], ['SciReasoner-Opi_GO_UniProtSeq_go', 'F1 Score'], +# ['SciReasoner-Opi_gSymbol2Cancer_gene_symbol_to_cancer', 'F1 Score'], ['SciReasoner-Opi_gSymbol2Tissue_gene_symbol_to_tissue', 'F1 Score'], +# ['SciReasoner-Opi_Keywords_CASPSimilarSeq_keywords', 'F1 Score'], ['SciReasoner-Opi_Keywords_IDFilterSeq_keywords', 'F1 Score'], +# ['SciReasoner-Opi_Keywords_UniProtSeq_keywords', 'F1 Score'], ['SciReasoner-Opi_Subcellular_localization_subcell_loc', 'Accuracy'], +# ['SciReasoner-PEER_solubility', 'accuracy'], ['SciReasoner-PEER_stability', 'accuracy'], ['SciReasoner-PEER_human_ppi', 'accuracy'], ['SciReasoner-PEER_yeast_ppi', 'accuracy'], +# ['SciReasoner-unconditional_material_generation', 'smact_validity_ratio_in_all'], +# ['SciReasoner-unconditional_RNA_generation', 'average_mfe'], ['SciReasoner-unconditional_protein_generation', 'valid_rate'], +# ['SciReasoner-unconditional_molecule_generation', 'validity'] +# ] +# ) + +# mini set summarizer +summarizer = dict( + dataset_abbrs=[ + ['SciReasoner-bio_instruction-antibody_antigen-mini', 'MCC'], ['SciReasoner-bio_instruction-rna_protein_interaction-mini', 'MCC'], ['SciReasoner-bio_instruction-emp-mini', 'MCC'], + ['SciReasoner-bio_instruction-enhancer_activity-mini', 'PCC'], ['SciReasoner-bio_instruction-tf_m-mini', 'MCC'], ['SciReasoner-bio_instruction-Isoform-mini', 'R2'], ['SciReasoner-bio_instruction-Modification-mini', 'AUC'], + ['SciReasoner-bio_instruction-MeanRibosomeLoading-mini', 'R2'], ['SciReasoner-bio_instruction-ProgrammableRNASwitches-mini', 'R2'], ['SciReasoner-bio_instruction-CRISPROnTarget-mini', 'spearman'], + ['SciReasoner-bio_instruction-promoter_enhancer_interaction-mini', 'MCC'], ['SciReasoner-bio_instruction-sirnaEfficiency-mini', 'mixed_score'], ['SciReasoner-bio_instruction-cpd-mini', 'MCC'], + ['SciReasoner-bio_instruction-pd-mini', 'MCC'], ['SciReasoner-bio_instruction-tf_h-mini', 'MCC'], + ['SciReasoner-Gue_cpd-prom_core_all-mini', 'matthews_correlation_all'], + ['SciReasoner-Gue_cpd-prom_core_notata-mini', 'matthews_correlation_all'], + ['SciReasoner-Gue_cpd-prom_core_tata-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_all-mini', 'matthews_correlation_all'], + ['SciReasoner-Gue_pd-prom_300_notata-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_tata-mini', 'matthews_correlation_all'], + ['SciReasoner-Gue_tf-h-0-mini', 'matthews_correlation_all'], 
['SciReasoner-Gue_tf-h-1-mini', 'matthews_correlation_all'], + ['SciReasoner-Gue_tf-h-2-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-3-mini', 'matthews_correlation_all'], + ['SciReasoner-Gue_tf-h-4-mini', 'matthews_correlation_all'], ['SciReasoner-smol_forward_synthesis-mini', 'top1_exact_match'], + ['SciReasoner-smol_retrosynthesis-mini', 'top1_exact_match'], ['SciReasoner-smol_molecule_captioning-mini', 'meteor_score'], + ['SciReasoner-smol_molecule_generation-mini', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-i2f-mini', 'top1_ele_match'], + ['SciReasoner-smol_name_conversion-i2s-mini', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-s2f-mini', 'top1_ele_match'], + ['SciReasoner-smol_name_conversion-s2i-mini', 'top1_split_match'], ['SciReasoner-smol_property_prediction-esol-mini', 'RMSE'], + ['SciReasoner-smol_property_prediction-lipo-mini', 'RMSE'], ['SciReasoner-smol_property_prediction-bbbp-mini', 'accuracy'], + ['SciReasoner-smol_property_prediction-clintox-mini', 'accuracy'], ['SciReasoner-smol_property_prediction-hiv-mini', 'accuracy'], + ['SciReasoner-smol_property_prediction-sider-mini', 'accuracy'], ['SciReasoner-retrosynthesis_USPTO_50K-mini', 'Top-1 Accuracy'], + ['SciReasoner-LLM4Mat_MP_IsStable-mini', 'AUC'], ['SciReasoner-LLM4Mat_MP_IsGapDirect-mini', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_IsDirect-mini', 'AUC'], + ['SciReasoner-LLM4Mat_SNUMAT_IsDirect_HSE-mini', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_SOC-mini', 'AUC'], ['SciReasoner-LLM4Mat_MP_FEPA-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_MP_Bandgap-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_EPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Ehull-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Efermi-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_MP_Density-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_DensityAtomic-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Volume-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_OPT-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_TotEn-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_Ehull-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_MBJ-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Kv-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_Gv-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_SLME-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Spillage-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_Epsx_OPT-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Dielectric_DFPT-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_dij-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_eij-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_MaxEFG-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_ExfEn-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_AvgMe-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_nSeebeck-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_nPF-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_pSeebeck-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISDFT_pPF-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA_Optical-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE_Optical-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_GNoME_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_DEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Bandgap-mini', 'MAD/MAE'], + 
['SciReasoner-LLM4Mat_GNoME_TotEn-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Volume-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Density-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_hMOF_MaxCO2-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_MinCO2-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_LCD-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_hMOF_PLD-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_VoidFraction-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_SA_m2g-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_hMOF_SA_m2cm3-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_EPA-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_Cantor_HEA_Ehull-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_VPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_TotEn-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_QMOF_Bandgap-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_LCD-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_PLD-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISQETB_EPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_IndirBandgap-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_JARVISQETB_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_TotEn-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OQMD_Bandgap-mini', 'MAD/MAE'], + ['SciReasoner-LLM4Mat_OQMD_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OMDB_Bandgap-mini', 'MAD/MAE'], + ['SciReasoner-composition_to_material_generation-mini', 'smact_validity_ratio_in_all_%'], + ['SciReasoner-bulk_modulus_to_material_generation-mini', 'smact_validity_ratio_in_all_%'], + ['SciReasoner-mol_instruction_chemical_disease_interaction_extraction-mini', 'f1'], + ['SciReasoner-mol_instruction_chemical_entity_recognition-mini', 'f1'], + ['SciReasoner-mol_instruction_chemical_protein_interaction_extraction-mini', 'f1'], + ['SciReasoner-mol_instruction_multi_choice_question-mini', 'accuracy'], + ['SciReasoner-mol_instruction_open_question-mini', 'bert_score'], + ['SciReasoner-mol_instruction_true_or_false_question-mini', 'accuracy'], + ['SciReasoner-mol_instruction_property_prediction_str-mini', 'mae'], + ['SciReasoner-mol_instruction_description_guided_molecule_design-mini', 'exact_match_score'], + ['SciReasoner-mol_instruction_forward_reaction_prediction-mini', 'exact_match_score'], + ['SciReasoner-mol_instruction_retrosynthesis-mini', 'exact_match_score'], + ['SciReasoner-mol_instruction_reagent_prediction-mini', 'exact_match_score'], + ['SciReasoner-mol_instruction_molecular_description_generation-mini', 'rougeL'], + ['SciReasoner-mol_instruction_catalytic_activity-mini', 'rougeL'], ['SciReasoner-mol_instruction_domain_motif-mini', 'rougeL'], + ['SciReasoner-mol_instruction_general_function-mini', 'rougeL'], ['SciReasoner-mol_instruction_protein_function-mini', 'rougeL'], + ['SciReasoner-mol_instruction_protein_design-mini', 'Max SW score'], ['SciReasoner-Opi_EC_number_CLEAN_EC_number_new-mini', 'Accuracy'], + ['SciReasoner-Opi_EC_number_CLEAN_EC_number_price-mini', 'Accuracy'], ['SciReasoner-Opi_Fold_type_fold_type-mini', 'Accuracy'], + ['SciReasoner-Opi_Function_CASPSimilarSeq_function-mini', 'ROUGE-L'], ['SciReasoner-Opi_Function_IDFilterSeq_function-mini', 'ROUGE-L'], + ['SciReasoner-Opi_Function_UniProtSeq_function-mini', 'ROUGE-L'], ['SciReasoner-Opi_gName2Cancer_gene_name_to_cancer-mini', 'F1 Score'], + ['SciReasoner-Opi_GO_CASPSimilarSeq_go-mini', 'F1 Score'], ['SciReasoner-Opi_GO_IDFilterSeq_go-mini', 'F1 Score'], ['SciReasoner-Opi_GO_UniProtSeq_go-mini', 'F1 Score'], + 
['SciReasoner-Opi_gSymbol2Cancer_gene_symbol_to_cancer-mini', 'F1 Score'], ['SciReasoner-Opi_gSymbol2Tissue_gene_symbol_to_tissue-mini', 'F1 Score'], + ['SciReasoner-Opi_Keywords_CASPSimilarSeq_keywords-mini', 'F1 Score'], ['SciReasoner-Opi_Keywords_IDFilterSeq_keywords-mini', 'F1 Score'], + ['SciReasoner-Opi_Keywords_UniProtSeq_keywords-mini', 'F1 Score'], ['SciReasoner-Opi_Subcellular_localization_subcell_loc-mini', 'Accuracy'], + ['SciReasoner-PEER_solubility-mini', 'accuracy'], ['SciReasoner-PEER_stability-mini', 'accuracy'], ['SciReasoner-PEER_human_ppi-mini', 'accuracy'], ['SciReasoner-PEER_yeast_ppi-mini', 'accuracy'], + ['SciReasoner-unconditional_material_generation-mini', 'smact_validity_ratio_in_all'], + ['SciReasoner-unconditional_RNA_generation-mini', 'average_mfe'], ['SciReasoner-unconditional_protein_generation-mini', 'valid_rate'], + ['SciReasoner-unconditional_molecule_generation-mini', 'validity'] + ] +) \ No newline at end of file diff --git a/opencompass/configs/datasets/SciReasoner/GUE_gen.py b/opencompass/configs/datasets/SciReasoner/GUE_gen.py new file mode 100644 index 000000000..77deab5da --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/GUE_gen.py @@ -0,0 +1,77 @@ +from opencompass.datasets import ( + GUE_Dataset, + GUE_Evaluator, + GUE_postprocessor +) +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_retriever import ZeroRetriever + +GUE_sub_tasks = [ + 'cpd-prom_core_all', + 'cpd-prom_core_notata', + 'cpd-prom_core_tata', + 'pd-prom_300_all', + 'pd-prom_300_notata', + 'pd-prom_300_tata', + 'tf-h-0', + 'tf-h-1', + 'tf-h-2', + 'tf-h-3', + 'tf-h-4', +] + +GUE_reader_cfg = dict(input_columns=['input'], output_column='output') + +GUE_datasets = [] +mini_GUE_datasets = [] + +for name in GUE_sub_tasks: + + GUE_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='{input}'), + ] + ), + ), + retriever=dict( + type=ZeroRetriever + ), + inferencer=dict(type=GenInferencer), + ) + + GUE_eval_cfg = dict( + evaluator=dict( + type=GUE_Evaluator + ), + pred_role='BOT', + pred_postprocessor=dict(type=GUE_postprocessor), + dataset_postprocessor=dict(type=GUE_postprocessor), + ) + + GUE_datasets.append( + dict( + abbr=f'SciReasoner-Gue_{name}', + type=GUE_Dataset, + path='opencompass/SciReasoner-GUE', + task=name, + reader_cfg=GUE_reader_cfg, + infer_cfg=GUE_infer_cfg, + eval_cfg=GUE_eval_cfg, + ) + ) + mini_GUE_datasets.append( + dict( + abbr=f'SciReasoner-Gue_{name}-mini', + type=GUE_Dataset, + path='opencompass/SciReasoner-GUE', + task=name, + mini_set=True, + reader_cfg=GUE_reader_cfg, + infer_cfg=GUE_infer_cfg, + eval_cfg=GUE_eval_cfg, + ) + ) diff --git a/opencompass/configs/datasets/SciReasoner/LLM4Mat_gen.py b/opencompass/configs/datasets/SciReasoner/LLM4Mat_gen.py new file mode 100644 index 000000000..323538f29 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/LLM4Mat_gen.py @@ -0,0 +1,290 @@ +from opencompass.datasets import ( + LLM4MatDataset, + LLM4Mat_Evaluator, + LLM4Mat_postprocessor +) +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_retriever import ZeroRetriever + +LLM4Mat_sub_tasks = \ + {'MP_FEPA': {'property': 'formation_energy_per_atom', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_Bandgap': {'property': 'band_gap', + 
'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_EPA': {'property': 'GGA-PBE-based_energy_per_atom', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_Ehull': {'property': 'energy_above_hull', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_Efermi': {'property': 'efermi', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_Density': {'property': 'density', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_DensityAtomic': {'property': 'density_atomic', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_Volume': {'property': 'volume', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_IsStable': {'property': 'is_stable', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'MP_IsGapDirect': {'property': 'is_gap_direct', + 'test_path': 'mp/test/data.json', + 'train_path': 'mp/dev/data.json'}, + 'JARVISDFT_FEPA': {'property': 'formation_energy_peratom', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Bandgap_OPT': {'property': 'optb88vdw_bandgap', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_TotEn': {'property': 'optb88vdw_total_energy', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Ehull': {'property': 'ehull', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Bandgap_MBJ': {'property': 'mbj_bandgap', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Kv': {'property': 'bulk_modulus_kv', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Gv': {'property': 'shear_modulus_gv', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_SLME': {'property': 'slme', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Spillage': {'property': 'spillage', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Epsx_OPT': {'property': 'mepsx', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Dielectric_DFPT': {'property': 'dfpt_piezo_max_dielectric', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Max_Piezo_dij': {'property': 'dfpt_piezo_max_dij', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_Max_Piezo_eij': {'property': 'dfpt_piezo_max_eij', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_MaxEFG': {'property': 'max_efg', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_ExfEn': {'property': 'exfoliation_energy', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_AvgMe': {'property': 'avg_elec_mass', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_nSeebeck': {'property': 'n-Seebeck', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_nPF': {'property': 'n-powerfact', + 'test_path': 'jarvis_dft/test/data.json', + 
'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_pSeebeck': {'property': 'p-Seebeck', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'JARVISDFT_pPF': {'property': 'p-powerfact', + 'test_path': 'jarvis_dft/test/data.json', + 'train_path': 'jarvis_dft/dev/data.json'}, + 'SNUMAT_Bandgap_GGA': {'property': 'Band_gap_GGA', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'SNUMAT_Bandgap_HSE': {'property': 'Band_gap_HSE', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'SNUMAT_Bandgap_GGA_Optical': {'property': 'Band_gap_GGA_optical', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'SNUMAT_Bandgap_HSE_Optical': {'property': 'Band_gap_HSE_optical', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'SNUMAT_IsDirect': {'property': 'Direct_or_indirect', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'SNUMAT_IsDirect_HSE': {'property': 'Direct_or_indirect_HSE', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'SNUMAT_SOC': {'property': 'SOC', + 'test_path': 'snumat/test/data.json', + 'train_path': 'snumat/dev/data.json'}, + 'GNoME_FEPA': {'property': 'Formation_Energy_Per_Atom', + 'test_path': 'gnome/test/data.json', + 'train_path': 'gnome/dev/data.json'}, + 'GNoME_DEPA': {'property': 'Decomposition_Energy_Per_Atom', + 'test_path': 'gnome/test/data.json', + 'train_path': 'gnome/dev/data.json'}, + 'GNoME_Bandgap': {'property': 'Bandgap', + 'test_path': 'gnome/test/data.json', + 'train_path': 'gnome/dev/data.json'}, + 'GNoME_TotEn': {'property': 'Corrected_Energy', + 'test_path': 'gnome/test/data.json', + 'train_path': 'gnome/dev/data.json'}, + 'GNoME_Volume': {'property': 'Volume', + 'test_path': 'gnome/test/data.json', + 'train_path': 'gnome/dev/data.json'}, + 'GNoME_Density': {'property': 'Density', + 'test_path': 'gnome/test/data.json', + 'train_path': 'gnome/dev/data.json'}, + 'hMOF_MaxCO2': {'property': 'max_co2_adsp', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'hMOF_MinCO2': {'property': 'min_co2_adsp', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'hMOF_LCD': {'property': 'lcd', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'hMOF_PLD': {'property': 'pld', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'hMOF_VoidFraction': {'property': 'void_fraction', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'hMOF_SA_m2g': {'property': 'surface_area_m2g', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'hMOF_SA_m2cm3': {'property': 'surface_area_m2cm3', + 'test_path': 'hmof/test/data.json', + 'train_path': 'hmof/dev/data.json'}, + 'Cantor_HEA_FEPA': {'property': 'Ef_per_atom', + 'test_path': 'cantor_hea/test/data.json', + 'train_path': 'cantor_hea/dev/data.json'}, + 'Cantor_HEA_EPA': {'property': 'e_per_atom', + 'test_path': 'cantor_hea/test/data.json', + 'train_path': 'cantor_hea/dev/data.json'}, + 'Cantor_HEA_Ehull': {'property': 'e_above_hull', + 'test_path': 'cantor_hea/test/data.json', + 'train_path': 'cantor_hea/dev/data.json'}, + 'Cantor_HEA_VPA': {'property': 'volume_per_atom', + 'test_path': 'cantor_hea/test/data.json', + 'train_path': 'cantor_hea/dev/data.json'}, + 'QMOF_TotEn': {'property': 'energy_total', + 'test_path': 'qmof/test/data.json', + 
'train_path': 'qmof/dev/data.json'}, + 'QMOF_Bandgap': {'property': 'bandgap', + 'test_path': 'qmof/test/data.json', + 'train_path': 'qmof/dev/data.json'}, + 'QMOF_LCD': {'property': 'lcd', + 'test_path': 'qmof/test/data.json', + 'train_path': 'qmof/dev/data.json'}, + 'QMOF_PLD': {'property': 'pld', + 'test_path': 'qmof/test/data.json', + 'train_path': 'qmof/dev/data.json'}, + 'JARVISQETB_EPA': {'property': 'TB-based_energy_per_atom', + 'test_path': 'jarvis_qetb/test/data.json', + 'train_path': 'jarvis_qetb/dev/data.json'}, + 'JARVISQETB_IndirBandgap': {'property': 'indir_gap', + 'test_path': 'jarvis_qetb/test/data.json', + 'train_path': 'jarvis_qetb/dev/data.json'}, + 'JARVISQETB_FEPA': {'property': 'f_enp', + 'test_path': 'jarvis_qetb/test/data.json', + 'train_path': 'jarvis_qetb/dev/data.json'}, + 'JARVISQETB_TotEn': {'property': 'final_energy', + 'test_path': 'jarvis_qetb/test/data.json', + 'train_path': 'jarvis_qetb/dev/data.json'}, + 'OQMD_Bandgap': {'property': 'bandgap', + 'test_path': 'oqmd/test/data.json', + 'train_path': 'oqmd/dev/data.json'}, + 'OQMD_FEPA': {'property': 'e_form', + 'test_path': 'oqmd/test/data.json', + 'train_path': 'oqmd/dev/data.json'}, + 'OMDB_Bandgap': {'property': 'bandgap', + 'test_path': 'omdb/test/data.json', + 'train_path': 'omdb/dev/data.json'} + } + +# Allowed values for the categorical (non-numeric) properties. +non_numeric_props_options = { + 'Direct_or_indirect': ['Direct', 'Indirect'], + 'Direct_or_indirect_HSE': ['Direct', 'Indirect'], + 'SOC': [True, False], + 'is_gap_direct': [True, False], + 'is_stable': [True, False], +} + +LLM4Mat_reader_cfg = dict(input_columns=['input'], output_column='output') + +LLM4Mat_datasets = [] +mini_LLM4Mat_datasets = [] + + +for name, info in LLM4Mat_sub_tasks.items(): + prop = info['property'] + test_path = info['test_path'] + train_path = info['train_path'] + + # Categorical and numeric properties share the same zero-shot template; + # the candidate answers for categorical properties are listed in + # non_numeric_props_options above. + prompt_template = dict( + round=[ + dict(role='HUMAN', prompt='{input}'), + ] + ) + + LLM4Mat_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=prompt_template, + ), + retriever=dict( + type=ZeroRetriever + ), + inferencer=dict(type=GenInferencer), + ) + + LLM4Mat_eval_cfg = dict( + evaluator=dict(type=LLM4Mat_Evaluator), + pred_role='BOT', + pred_postprocessor=dict(type=LLM4Mat_postprocessor, property=prop), + dataset_postprocessor=dict(type=LLM4Mat_postprocessor, property=prop), + ) + + LLM4Mat_datasets.append( + dict( + abbr=f'SciReasoner-LLM4Mat_{name}', + type=LLM4MatDataset, + path='opencompass/SciReasoner-LLM4Mat', + train_path=train_path, + test_path=test_path, + property=prop, + reader_cfg=LLM4Mat_reader_cfg, + infer_cfg=LLM4Mat_infer_cfg, + eval_cfg=LLM4Mat_eval_cfg, + ) + ) + mini_LLM4Mat_datasets.append( + dict( + abbr=f'SciReasoner-LLM4Mat_{name}-mini', + type=LLM4MatDataset, + path='opencompass/SciReasoner-LLM4Mat', + train_path=train_path, + test_path=test_path, + property=prop, + mini_set=True, + reader_cfg=LLM4Mat_reader_cfg, + infer_cfg=LLM4Mat_infer_cfg, + eval_cfg=LLM4Mat_eval_cfg, + ) + ) diff --git a/opencompass/configs/datasets/SciReasoner/UMG.py b/opencompass/configs/datasets/SciReasoner/UMG.py new file mode 100644 index 000000000..e3edfb536 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/UMG.py @@ -0,0 +1,49 @@ +from 
opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import UMG_Dataset, UMG_Evaluator + +INFER_TEMPLATE = '''Generate a molecule with ''' + +reader_cfg = dict(input_columns=['input'], output_column='output') + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role='HUMAN', + prompt='{input}', + ), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +eval_cfg = dict( + evaluator=dict( + type=UMG_Evaluator, + ), +) + +UMG_Datasets = [ + dict( + abbr='SciReasoner-unconditional_molecule_generation', + type=UMG_Dataset, + # max_cut=20, # Optionally limit the maximum number of samples + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg) +] +mini_UMG_Datasets = [ + dict( + abbr='SciReasoner-unconditional_molecule_generation-mini', + type=UMG_Dataset, + max_cut=150, # Optionally limit the maximum number of samples + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg) +] diff --git a/opencompass/configs/datasets/SciReasoner/UPG.py b/opencompass/configs/datasets/SciReasoner/UPG.py new file mode 100644 index 000000000..aee2d19e7 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/UPG.py @@ -0,0 +1,67 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import UPGDataset, UPG_postprocess, UPG_Evaluator + +reader_cfg = dict(input_columns=['input'], output_column='output') + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin=[ + '', + ], + round=[ + dict(role='HUMAN', prompt='{input}'), + ] + ), + ice_token='', + ), + ice_template=dict( + type=PromptTemplate, + template=dict( + round=[ + # dict(role='HUMAN', prompt='{input} /no_think'), # for Qwen3 + dict(role='HUMAN', prompt='{input}'), + dict(role='BOT', prompt='{output}'), + ] + ) + ), + # The retriever is responsible for retrieving examples and formatting them using ice_template + retriever=dict( + # type=FixKRetriever, + # fix_id_list=[0, 1, 2, 3, 4], # Use the first 5 examples + type=ZeroRetriever, # For our trained models, use zero-shot + ), + inferencer=dict( + type=GenInferencer, + ), +) + +eval_cfg = dict( + evaluator=dict( + type=UPG_Evaluator, + ), + pred_postprocessor=dict(type=UPG_postprocess), + dataset_postprocessor=dict(type=UPG_postprocess), +) + +UPG_datasets = [ + dict( + abbr='SciReasoner-unconditional_protein_generation', + type=UPGDataset, + # max_cut=20, # Optionally limit the maximum number of samples + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg) +] +mini_UPG_datasets = [ + dict( + abbr='SciReasoner-unconditional_protein_generation-mini', + type=UPGDataset, + max_cut=150, # Optionally limit the maximum number of samples + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg) +] diff --git a/opencompass/configs/datasets/SciReasoner/bio_instruction_gen.py b/opencompass/configs/datasets/SciReasoner/bio_instruction_gen.py new file mode 100644 index 000000000..2e7972526 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/bio_instruction_gen.py @@ -0,0 +1,75 @@ +from mmengine.config import read_base +from opencompass.openicl.icl_prompt_template import PromptTemplate +from 
opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import Bioinstruction_Dataset, bio_instruction_Evaluator + +reader_cfg = dict( + input_columns=['input'], + output_column='output' +) + +MODEL_NAME = r'model' + +bio_instruction_datasets = [] +mini_bio_instruction_datasets = [] + +path = ['antibody_antigen', 'rna_protein_interaction', 'emp', 'enhancer_activity', 'tf_m', 'Isoform', 'Modification', + 'MeanRibosomeLoading', 'ProgrammableRNASwitches', + 'CRISPROnTarget', 'promoter_enhancer_interaction', 'sirnaEfficiency', 'cpd', 'pd', 'tf_h'] +extra_path = ['Fluorescence', 'FunctionEC', 'Stability', 'Solubility', 'Thermostability'] # the protein tasks + +for task in path: + infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='{input}'), + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), + ) + + eval_cfg = dict( + evaluator=dict(type=bio_instruction_Evaluator, + path='opencompass/SciReasoner-bio_instruction', + task=task, + model_name=MODEL_NAME), + pred_role='BOT', + # num_gpus=1 + ) + eval_mini_cfg = dict( + evaluator=dict(type=bio_instruction_Evaluator, + path='opencompass/SciReasoner-bio_instruction', + task=task, + mini_set=True, + model_name=MODEL_NAME), + pred_role='BOT', + # num_gpus=1 + ) + + bio_instruction_datasets.append( + dict( + type=Bioinstruction_Dataset, + abbr=f'SciReasoner-bio_instruction-{task}', + path='opencompass/SciReasoner-bio_instruction', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + ) + mini_bio_instruction_datasets.append( + dict( + type=Bioinstruction_Dataset, + abbr=f'SciReasoner-bio_instruction-{task}-mini', + path='opencompass/SciReasoner-bio_instruction', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_mini_cfg, + ) + ) diff --git a/opencompass/configs/datasets/SciReasoner/bulk_modulus_material_gen.py b/opencompass/configs/datasets/SciReasoner/bulk_modulus_material_gen.py new file mode 100644 index 000000000..662be5925 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/bulk_modulus_material_gen.py @@ -0,0 +1,52 @@ +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.datasets import Bulk_modulus_material_Dataset, material_Evaluator, material_postprocessor + +modulus_material_reader = dict(input_columns=['input'], output_column='output') + +modulus_material_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role='HUMAN', + prompt='{input}', + ), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +modulus_material_eval_cfg = dict( + evaluator=dict( + type=material_Evaluator, + data_path='opencompass/SciReasoner-Conditional_generation', + ), + pred_postprocessor=dict(type=material_postprocessor), +) + +modulus_material_datasets = [ + dict( + abbr='SciReasoner-bulk_modulus_to_material_generation', + type=Bulk_modulus_material_Dataset, + path='opencompass/SciReasoner-Conditional_generation', + reader_cfg=modulus_material_reader, + infer_cfg=modulus_material_infer_cfg, + eval_cfg=modulus_material_eval_cfg, + ) +] +mini_modulus_material_datasets = [ + dict( + abbr='SciReasoner-bulk_modulus_to_material_generation-mini', + 
type=Bulk_modulus_material_Dataset, + path='opencompass/SciReasoner-Conditional_generation', + mini_set=True, + reader_cfg=modulus_material_reader, + infer_cfg=modulus_material_infer_cfg, + eval_cfg=modulus_material_eval_cfg, + ) +] diff --git a/opencompass/configs/datasets/SciReasoner/composition_material_gen.py b/opencompass/configs/datasets/SciReasoner/composition_material_gen.py new file mode 100644 index 000000000..15beb5a86 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/composition_material_gen.py @@ -0,0 +1,64 @@ +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.datasets import Composition_material_Dataset, composition_Evaluator, material_postprocessor + +generation_kwargs = dict( + do_sample=True, + # top_p=0.8, + # min_p=0, + temperature=0.40, + # top_k=20, + # repetition_penalty=1, + # "<|endoftext|>": 151643 "<|im_end|>": 151645 + # eos_token_id=[151643, 151645], +) + +composition_material_reader = dict(input_columns=['input'], output_column='output') + +composition_material_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role='HUMAN', + prompt='{input}', + ), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +composition_material_eval_cfg = dict( + evaluator=dict( + type=composition_Evaluator, + data_path='opencompass/SciReasoner-Conditional_generation', + ), + pred_postprocessor=dict(type=material_postprocessor), +) + + +composition_material_datasets = [ + dict( + abbr='SciReasoner-composition_to_material_generation', + type=Composition_material_Dataset, + path='opencompass/SciReasoner-Conditional_generation', + reader_cfg=composition_material_reader, + infer_cfg=composition_material_infer_cfg, + eval_cfg=composition_material_eval_cfg, + ) +] +mini_composition_material_datasets = [ + dict( + abbr='SciReasoner-composition_to_material_generation-mini', + type=Composition_material_Dataset, + path='opencompass/SciReasoner-Conditional_generation', + mini_set=True, + reader_cfg=composition_material_reader, + infer_cfg=composition_material_infer_cfg, + eval_cfg=composition_material_eval_cfg, + ) +] diff --git a/opencompass/configs/datasets/SciReasoner/mol_biotext_gen.py b/opencompass/configs/datasets/SciReasoner/mol_biotext_gen.py new file mode 100644 index 000000000..6b40ebdf6 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/mol_biotext_gen.py @@ -0,0 +1,110 @@ +# base config for LLM4Chem +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import Mol_Instructions_postprocess_BioText, Mol_Instructions_Evaluator_BioText, \ + Mol_Instructions_Dataset_BioText + +TASKS = [ + 'chemical_disease_interaction_extraction', + 'chemical_entity_recognition', + 'chemical_protein_interaction_extraction', + 'multi_choice_question', + 'open_question', + 'true_or_false_question' +] + +reader_cfg = dict(input_columns=['input'], output_column='output') + +mol_biotext_datasets = [] +mini_mol_biotext_datasets = [] + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt='{input}') + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer)) + 
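+# Prompt variants for the biotext tasks: the generic zero-shot config above is +# the default; the two configs below constrain true_or_false_question answers +# to Yes/Maybe/No and give chemical_entity_recognition a one-shot prompt +# assembled by FixKRetriever.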
+infer_cfg_true_or_false = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt="{input} Your answer should start with 'Yes' or 'Maybe' or 'No'") + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer)) + +infer_cfg_CER = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin=[ + # Optional but recommended: A system prompt for better instructions. + dict(role='SYSTEM', fallback_role='HUMAN', + prompt='There is a single choice question about chemistry. Answer the question directly.'), + # The placeholder is the ice_token string itself, used as a direct list element. + '</E>', + ], + round=[ + dict(role='HUMAN', prompt='Query: {input}'), + dict(role='BOT', prompt=''), + ] + ), + ice_token='</E>', + ), + ice_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='Query: {input}'), + dict(role='BOT', prompt='{output}'), + ] + ) + ), + retriever=dict( + type=FixKRetriever, + fix_id_list=[0, ], # Use only the first example as the demonstration + ), + inferencer=dict(type=GenInferencer) +) + +for task in TASKS: + eval_cfg = dict( + evaluator=dict(type=Mol_Instructions_Evaluator_BioText, task=task), + pred_postprocessor=dict(type=Mol_Instructions_postprocess_BioText, task=task), + dataset_postprocessor=dict(type=Mol_Instructions_postprocess_BioText, task=task), + ) + + if task == 'true_or_false_question': + apply_infer_cfg = infer_cfg_true_or_false + elif task == 'chemical_entity_recognition': + apply_infer_cfg = infer_cfg_CER + else: + apply_infer_cfg = infer_cfg + + mol_biotext_datasets.append( + dict( + abbr=f'SciReasoner-mol_instruction_{task}', + type=Mol_Instructions_Dataset_BioText, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + reader_cfg=reader_cfg, + infer_cfg=apply_infer_cfg, + eval_cfg=eval_cfg) + ) + mini_mol_biotext_datasets.append( + dict( + abbr=f'SciReasoner-mol_instruction_{task}-mini', + type=Mol_Instructions_Dataset_BioText, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=apply_infer_cfg, + eval_cfg=eval_cfg) + ) diff --git a/opencompass/configs/datasets/SciReasoner/mol_molecule_gen.py b/opencompass/configs/datasets/SciReasoner/mol_molecule_gen.py new file mode 100644 index 000000000..c13fcb36c --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/mol_molecule_gen.py @@ -0,0 +1,63 @@ +# base config for Mol-Instructions molecule tasks +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import Mol_Instructions_postprocess_Mol, Mol_Instructions_Evaluator_Mol, Mol_Instructions_Dataset + +TASKS = [ + 'property_prediction_str', + 'description_guided_molecule_design', + 'forward_reaction_prediction', + 'retrosynthesis', + 'reagent_prediction', + 'molecular_description_generation' + ] + +reader_cfg = dict(input_columns=['input'], output_column='output') + +mol_mol_datasets = [] +mini_mol_mol_datasets = [] + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt='{input}'), + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer)) + +for task in TASKS: + eval_cfg = dict( + evaluator=dict(type=Mol_Instructions_Evaluator_Mol, task=task), + pred_postprocessor=dict(type=Mol_Instructions_postprocess_Mol, task=task), + 
dataset_postprocessor=dict(type=Mol_Instructions_postprocess_Mol, task=task), + ) + + mol_mol_datasets.append( + dict( + abbr=f'SciReasoner-mol_instruction_{task}', + type=Mol_Instructions_Dataset, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg + ) + ) + mini_mol_mol_datasets.append( + dict( + abbr=f'SciReasoner-mol_instruction_{task}-mini', + type=Mol_Instructions_Dataset, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg + ) + ) + + + diff --git a/opencompass/configs/datasets/SciReasoner/mol_protein_gen.py b/opencompass/configs/datasets/SciReasoner/mol_protein_gen.py new file mode 100644 index 000000000..b5d814db3 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/mol_protein_gen.py @@ -0,0 +1,83 @@ +# base config for LLM4Chem +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import (Mol_Instructions_postprocess_Protein, Mol_Instructions_Evaluator_Protein, + Mol_Instructions_Dataset, Mol_Instructions_postprocess_Protein_Design, + Mol_Instructions_Evaluator_Protein_Design, Mol_Instructions_Dataset_Protein_Design) + +TASKS = [ + 'catalytic_activity', + 'domain_motif', + 'general_function', + 'protein_function', +] + +reader_cfg = dict(input_columns=['input'], output_column='output') + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt='{input}'), + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer)) + +eval_cfg = dict( + evaluator=dict(type=Mol_Instructions_Evaluator_Protein), + pred_postprocessor=dict(type=Mol_Instructions_postprocess_Protein), + dataset_postprocessor=dict(type=Mol_Instructions_postprocess_Protein), +) + +eval_cfg_protein_design = dict( + evaluator=dict(type=Mol_Instructions_Evaluator_Protein_Design), + pred_postprocessor=dict(type=Mol_Instructions_postprocess_Protein_Design), + dataset_postprocessor=dict(type=Mol_Instructions_postprocess_Protein_Design), +) + +mol_protein_datasets = [] +mini_mol_protein_datasets = [] + +for task in TASKS: + mol_protein_datasets.append( + dict( + abbr=f'SciReasoner-mol_instruction_{task}', + type=Mol_Instructions_Dataset, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg)) + mini_mol_protein_datasets.append( + dict( + abbr=f'SciReasoner-mol_instruction_{task}-mini', + type=Mol_Instructions_Dataset, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg)) + +task = 'protein_design' +mol_protein_datasets.append( + dict( + abbr='SciReasoner-mol_instruction_protein_design', + type=Mol_Instructions_Dataset_Protein_Design, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg_protein_design)) +mini_mol_protein_datasets.append( + dict( + abbr='SciReasoner-mol_instruction_protein_design-mini', + type=Mol_Instructions_Dataset_Protein_Design, + path='opencompass/SciReasoner-Mol_Instructions', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg_protein_design)) diff --git 
a/opencompass/configs/datasets/SciReasoner/opi_gen.py b/opencompass/configs/datasets/SciReasoner/opi_gen.py new file mode 100644 index 000000000..5ad93dccc --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/opi_gen.py @@ -0,0 +1,83 @@ +# base config for opi +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import opi_postprocess, Opi_Evaluator, OpiDataset + + +all_datasets = [] +mini_all_datasets = [] + +# Root directory where the datasets are located +root_dir = '/path/OPI_test' + +subtask_dirs = [ + 'EC_number_CLEAN_EC_number_new', + 'EC_number_CLEAN_EC_number_price', + 'Fold_type_fold_type', + 'Function_CASPSimilarSeq_function', + 'Function_IDFilterSeq_function', + 'Function_UniProtSeq_function', + 'gName2Cancer_gene_name_to_cancer', + 'GO_CASPSimilarSeq_go', + 'GO_IDFilterSeq_go', + 'GO_UniProtSeq_go', + 'gSymbol2Cancer_gene_symbol_to_cancer', + 'gSymbol2Tissue_gene_symbol_to_tissue', + 'Keywords_CASPSimilarSeq_keywords', + 'Keywords_IDFilterSeq_keywords', + 'Keywords_UniProtSeq_keywords', + 'Subcellular_localization_subcell_loc', +] + +for subtask_name in subtask_dirs: + # Common configs for inference + + reader_cfg = dict(input_columns=['input'], output_column='output') + + infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt='{input}'), + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict( + type=GenInferencer, + ) + ) + + # Extract high-level task from subdir name for the evaluator (e.g., 'EC_number') + task_type = subtask_name.split('_')[0] + + eval_cfg = dict( + evaluator=dict(type=Opi_Evaluator, task=task_type), + pred_postprocessor=dict(type=opi_postprocess, task=task_type), + dataset_postprocessor=dict(type=opi_postprocess, task=task_type), + ) + + # Create the dataset dictionary for the current subtask + all_datasets.append( + dict( + abbr=f'SciReasoner-Opi_{subtask_name}', + type=OpiDataset, + path='opencompass/SciReasoner-OPI', + task=subtask_name, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg + ).copy() + ) + mini_all_datasets.append( + dict( + abbr=f'SciReasoner-Opi_{subtask_name}-mini', + type=OpiDataset, + path='opencompass/SciReasoner-OPI', + task=subtask_name, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg + ).copy() + ) diff --git a/opencompass/configs/datasets/SciReasoner/peer_gen.py b/opencompass/configs/datasets/SciReasoner/peer_gen.py new file mode 100644 index 000000000..989663a24 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/peer_gen.py @@ -0,0 +1,99 @@ +# base config for LLM4Chem +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import PEER_postprocess, PEER_Evaluator, PEER_Dataset, PEER_postprocess_float_compare, \ + PEER_postprocess_default + +TASKS = [ + 'solubility', + 'stability', + 'human_ppi', + 'yeast_ppi', +] + +reader_cfg = dict(input_columns=['input'], output_column='output') + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt='{input}.'), + ]), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict( + type=GenInferencer, + # max_out_len=2048, + ) +) + +eval_cfg = dict( + 
evaluator=dict(type=PEER_Evaluator), + pred_postprocessor=dict(type=PEER_postprocess), + dataset_postprocessor=dict(type=PEER_postprocess), +) + +# use default postprocess to remain the original output for LLM judgement. +# PEER_postprocess will be used in the evaluation stage to compare the output with the ground truth as a fast comparison. +eval_llm_cfg = dict( + evaluator=dict(type=PEER_Evaluator, + openai_key='EMPTY', gpt_model='gpt-4.1-mini'), + pred_postprocessor=dict(type=PEER_postprocess_default), + dataset_postprocessor=dict(type=PEER_postprocess_default), +) + +eval_stability_cfg = dict( + evaluator=dict(type=PEER_Evaluator, task='stability'), + pred_postprocessor=dict(type=PEER_postprocess_float_compare, compare_number=1), + dataset_postprocessor=dict(type=PEER_postprocess_float_compare, compare_number=1), +) + +PEER_datasets = [] +mini_PEER_datasets = [] + +for task in TASKS: + if task != 'stability': + PEER_datasets.append( + dict( + abbr=f'SciReasoner-PEER_{task}', + type=PEER_Dataset, + path='opencompass/SciReasoner-PEER', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_llm_cfg), + ) + mini_PEER_datasets.append( + dict( + abbr=f'SciReasoner-PEER_{task}-mini', + type=PEER_Dataset, + path='opencompass/SciReasoner-PEER', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_llm_cfg), + ) + else: + PEER_datasets.append( + dict( + abbr=f'SciReasoner-PEER_{task}', + type=PEER_Dataset, + path='opencompass/SciReasoner-PEER', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_stability_cfg), + ) + mini_PEER_datasets.append( + dict( + abbr=f'SciReasoner-PEER_{task}-mini', + type=PEER_Dataset, + path='opencompass/SciReasoner-PEER', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_stability_cfg), + ) diff --git a/opencompass/configs/datasets/SciReasoner/retrosynthesis_USPTO_gen.py b/opencompass/configs/datasets/SciReasoner/retrosynthesis_USPTO_gen.py new file mode 100644 index 000000000..5600301d8 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/retrosynthesis_USPTO_gen.py @@ -0,0 +1,74 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import RetrosynthesisEvaluator, Retrosynthesis_postprocess, LLM4ChemDataset + +reader_cfg = dict(input_columns=['input'], output_column='output') + + + +infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin=[ + dict(role='SYSTEM', fallback_role='HUMAN', prompt=''), + '', + ], + round = [ + # dict(role='HUMAN', prompt='Query: {input} /no_think'), # for Qwen3 + dict(role='HUMAN', prompt='{input}'), + ] + ), + ice_token='', + ), + ice_template=dict( + type=PromptTemplate, + template = dict( + round = [ + dict(role='HUMAN', prompt='{input}'), + ] + ) + ), + # retriever: responsible for retrieving and formatting examples using ice_template + retriever=dict( + # type=FixKRetriever, + # fix_id_list=[0, 1, 2, 3, 4], # Use the first 5 examples + type=ZeroRetriever, # For our trained model, use zero-shot + ), + inferencer=dict( + type=GenInferencer, + ), +) + +eval_cfg = dict( + evaluator=dict(type=RetrosynthesisEvaluator, beam_size=1, n_best=1), + pred_postprocessor=dict(type=Retrosynthesis_postprocess), + dataset_postprocessor=dict(type=Retrosynthesis_postprocess), +) + +task = 'retrosynthesis_uspto50k' + 
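+# With beam_size=1 and n_best=1 only the single top prediction is scored, +# which the summarizer reports as the 'Top-1 Accuracy' metric.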
+Retrosynthesis_datasets = [ + dict( + abbr='SciReasoner-retrosynthesis_USPTO_50K', + type=LLM4ChemDataset, + path='opencompass/SciReasoner-smol', + task=task, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg= eval_cfg, + ) +] +mini_Retrosynthesis_datasets = [ + dict( + abbr='SciReasoner-retrosynthesis_USPTO_50K-mini', + type=LLM4ChemDataset, + path='opencompass/SciReasoner-smol', + task=task, + mini_set=True, + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg= eval_cfg, + ) +] \ No newline at end of file diff --git a/opencompass/configs/datasets/SciReasoner/scireasoner_gen.py b/opencompass/configs/datasets/SciReasoner/scireasoner_gen.py new file mode 100644 index 000000000..66e0fc715 --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/scireasoner_gen.py @@ -0,0 +1,55 @@ +from mmengine.config import read_base + +with read_base(): + # scireasoner + from opencompass.configs.datasets.SciReasoner.bio_instruction_gen import bio_instruction_datasets, \ + mini_bio_instruction_datasets + from opencompass.configs.datasets.SciReasoner.composition_material_gen import \ + composition_material_datasets, mini_composition_material_datasets + from opencompass.configs.datasets.SciReasoner.GUE_gen import GUE_datasets, mini_GUE_datasets + from opencompass.configs.datasets.SciReasoner.smol_gen import all_datasets as smol_datasets, \ + mini_all_datasets as mini_smol_datasets + from opencompass.configs.datasets.SciReasoner.retrosynthesis_USPTO_gen import \ + Retrosynthesis_datasets as Retrosynthesis_uspto50k_datasets, \ + mini_Retrosynthesis_datasets as mini_Retrosynthesis_uspto50k_datasets + from opencompass.configs.datasets.SciReasoner.LLM4Mat_gen import LLM4Mat_datasets, mini_LLM4Mat_datasets + from opencompass.configs.datasets.SciReasoner.bulk_modulus_material_gen import modulus_material_datasets, \ + mini_modulus_material_datasets + from opencompass.configs.datasets.SciReasoner.mol_biotext_gen import mol_biotext_datasets, mini_mol_biotext_datasets + from opencompass.configs.datasets.SciReasoner.mol_molecule_gen import mol_mol_datasets, mini_mol_mol_datasets + from opencompass.configs.datasets.SciReasoner.mol_protein_gen import mol_protein_datasets, mini_mol_protein_datasets + from opencompass.configs.datasets.SciReasoner.opi_gen import all_datasets as opi_datasets, \ + mini_all_datasets as mini_opi_datasets + from opencompass.configs.datasets.SciReasoner.peer_gen import PEER_datasets, mini_PEER_datasets + from opencompass.configs.datasets.SciReasoner.unconditional_material_gen import uncond_material_datasets, \ + mini_uncond_material_datasets + from opencompass.configs.datasets.SciReasoner.unconditional_RNA_gen import uncond_RNA_datasets, \ + mini_uncond_RNA_datasets + from opencompass.configs.datasets.SciReasoner.UPG import \ + UPG_datasets as uncond_protein_datasets, mini_UPG_datasets as mini_uncond_protein_datasets + from opencompass.configs.datasets.SciReasoner.UMG import UMG_Datasets, mini_UMG_Datasets + +# full eval set +scireasoner_datasets_full = bio_instruction_datasets + composition_material_datasets + GUE_datasets + smol_datasets + \ + Retrosynthesis_uspto50k_datasets + LLM4Mat_datasets + modulus_material_datasets + \ + mol_biotext_datasets + mol_mol_datasets + mol_protein_datasets + opi_datasets + PEER_datasets + \ + uncond_material_datasets + uncond_RNA_datasets + uncond_protein_datasets + UMG_Datasets + +# mini eval set +scireasoner_datasets_mini = mini_bio_instruction_datasets + mini_composition_material_datasets + mini_GUE_datasets + mini_smol_datasets + \ + 
    mini_Retrosynthesis_uspto50k_datasets + mini_LLM4Mat_datasets + mini_modulus_material_datasets + \
+    mini_mol_biotext_datasets + mini_mol_mol_datasets + mini_mol_protein_datasets + mini_opi_datasets + mini_PEER_datasets + \
+    mini_uncond_material_datasets + mini_uncond_RNA_datasets + mini_uncond_protein_datasets + mini_UMG_Datasets
+
+# scireasoner_mini_datasets =\
+# (
+#     # mini_bio_instruction_datasets +
+#     # mini_composition_material_datasets +
+#     # mini_modulus_material_datasets +
+#     # mini_GUE_datasets +
+#     # mini_LLM4Mat_datasets +
+#     # mini_mol_biotext_datasets + mini_mol_mol_datasets + mini_mol_protein_datasets + mini_opi_datasets + mini_Retrosynthesis_uspto50k_datasets + mini_smol_datasets
+#     # mini_UMG_Datasets + mini_uncond_material_datasets
+#     mini_uncond_RNA_datasets
+#     # mini_uncond_protein_datasets
+# )
\ No newline at end of file
diff --git a/opencompass/configs/datasets/SciReasoner/smol_gen.py b/opencompass/configs/datasets/SciReasoner/smol_gen.py
new file mode 100644
index 000000000..cdd9a8556
--- /dev/null
+++ b/opencompass/configs/datasets/SciReasoner/smol_gen.py
@@ -0,0 +1,120 @@
+# base config for LLM4Chem
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import LLM4Chem_postprocess, LLM4Chem_Evaluator, LLM4ChemDataset
+
+TASKS = (
+    'forward_synthesis',
+    'retrosynthesis',
+    'molecule_captioning',
+    'molecule_generation',
+    'name_conversion-i2f',
+    'name_conversion-i2s',
+    'name_conversion-s2f',
+    'name_conversion-s2i',
+    'property_prediction-esol',
+    'property_prediction-lipo',
+    'property_prediction-bbbp',
+    'property_prediction-clintox',
+    'property_prediction-hiv',
+    'property_prediction-sider',
+)
+
+TASKS_single = (
+    'property_prediction-esol',
+    'property_prediction-lipo',
+    'property_prediction-bbbp',
+    'property_prediction-clintox',
+    'property_prediction-hiv',
+    'property_prediction-sider',
+)
+
+TASK_TAGS = {
+    'forward_synthesis': ('<SMILES>', '</SMILES>'),
+    'retrosynthesis': ('<SMILES>', '</SMILES>'),
+    'molecule_generation': ('<SMILES>', '</SMILES>'),
+    'molecule_captioning': (None, None),
+    'name_conversion-i2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
+    'name_conversion-i2s': ('<SMILES>', '</SMILES>'),
+    'name_conversion-s2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
+    'name_conversion-s2i': ('<IUPAC>', '</IUPAC>'),
+    'property_prediction-esol': ('<NUMBER>', '</NUMBER>'),
+    'property_prediction-lipo': ('<NUMBER>', '</NUMBER>'),
+    'property_prediction-bbbp': ('<BOOLEAN>', '</BOOLEAN>'),
+    'property_prediction-clintox': ('<BOOLEAN>', '</BOOLEAN>'),
+    'property_prediction-hiv': ('<BOOLEAN>', '</BOOLEAN>'),
+    'property_prediction-sider': ('<BOOLEAN>', '</BOOLEAN>'),
+}
+
+all_datasets = []
+mini_all_datasets = []
+
+for task in TASKS:
+
+    reader_cfg = dict(input_columns=['input'], output_column='output')
+
+    infer_cfg = dict(
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                begin=[
+                    # Optional but recommended: A system prompt for better instructions.
+                    dict(role='SYSTEM', fallback_role='HUMAN', prompt=''),
+                    # The placeholder is the ice_token string itself, used as a direct list element.
+                    '</E>',
+                ],
+                round=[
+                    dict(role='HUMAN', prompt='{input}'),
+                ]
+            ),
+            ice_token='</E>',
+        ),
+        ice_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                round=[
+                    dict(role='HUMAN', prompt='{input}'),
+                    dict(role='BOT', prompt='{output}'),
+                ]
+            )
+        ),
+        # retriever is responsible for retrieving examples and using ice_template to format them
+        retriever=dict(
+            # type=FixKRetriever,
+            # fix_id_list=[0, 1, 2, 3, 4],  # Use the first 5 examples
+            type=ZeroRetriever,  # For our trained model, use zero-shot
+        ),
+        inferencer=dict(
+            type=GenInferencer,
+        ))
+
+    eval_cfg = dict(
+        evaluator=dict(type=LLM4Chem_Evaluator, task=task),
+        pred_postprocessor=dict(type=LLM4Chem_postprocess, task=task),
+        dataset_postprocessor=dict(type=LLM4Chem_postprocess, task=task),
+    )
+
+    all_datasets.append(
+        dict(
+            abbr='SciReasoner-smol_' + task,
+            type=LLM4ChemDataset,
+            path='opencompass/SciReasoner-smol',
+            task=task,
+            reader_cfg=reader_cfg,
+            infer_cfg=infer_cfg,
+            eval_cfg=eval_cfg
+        )
+    )
+    mini_all_datasets.append(
+        dict(
+            abbr='SciReasoner-smol_' + task + '-mini',
+            type=LLM4ChemDataset,
+            path='opencompass/SciReasoner-smol',
+            task=task,
+            mini_set=True,
+            reader_cfg=reader_cfg,
+            infer_cfg=infer_cfg,
+            eval_cfg=eval_cfg
+        )
+    )
diff --git a/opencompass/configs/datasets/SciReasoner/unconditional_RNA_gen.py b/opencompass/configs/datasets/SciReasoner/unconditional_RNA_gen.py
new file mode 100644
index 000000000..cc6eb949e
--- /dev/null
+++ b/opencompass/configs/datasets/SciReasoner/unconditional_RNA_gen.py
@@ -0,0 +1,51 @@
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.datasets import Uncond_RNA_Dataset, RNA_Evaluator, RNA_postprocessor
+
+uncond_RNA_reader_cfg = dict(input_columns=['input'], output_column='output')
+
+
+uncond_RNA_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(
+                    role='HUMAN',
+                    prompt='{input}',
+                ),
+            ],
+        ),
+    ),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer),
+)
+
+uncond_RNA_eval_cfg = dict(
+    evaluator=dict(type=RNA_Evaluator),
+    pred_postprocessor=dict(type=RNA_postprocessor),
+)
+
+uncond_RNA_datasets = [
+    dict(
+        abbr='SciReasoner-unconditional_RNA_generation',
+        type=Uncond_RNA_Dataset,
+        num=5000,
+        prompt='Please generate a novel RNA sequence. ',
+        reader_cfg=uncond_RNA_reader_cfg,
+        infer_cfg=uncond_RNA_infer_cfg,
+        eval_cfg=uncond_RNA_eval_cfg,
+    )
+]
+mini_uncond_RNA_datasets = [
+    dict(
+        abbr='SciReasoner-unconditional_RNA_generation-mini',
+        type=Uncond_RNA_Dataset,
+        num=150,
+        prompt='Please generate a novel RNA sequence. 
', + reader_cfg=uncond_RNA_reader_cfg, + infer_cfg=uncond_RNA_infer_cfg, + eval_cfg=uncond_RNA_eval_cfg, + ) +] diff --git a/opencompass/configs/datasets/SciReasoner/unconditional_material_gen.py b/opencompass/configs/datasets/SciReasoner/unconditional_material_gen.py new file mode 100644 index 000000000..299ad47fd --- /dev/null +++ b/opencompass/configs/datasets/SciReasoner/unconditional_material_gen.py @@ -0,0 +1,50 @@ +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.datasets import Uncond_material_Dataset, uncond_material_Evaluator, material_postprocessor + +uncond_material_reader_cfg = dict(input_columns=['input'], output_column='output') + +uncond_material_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role='HUMAN', + prompt='{input}', + ), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +uncond_material_eval_cfg = dict( + evaluator=dict(type=uncond_material_Evaluator), + pred_postprocessor=dict(type=material_postprocessor), + ) + +uncond_material_datasets = [ + dict( + abbr='SciReasoner-unconditional_material_generation', + type=Uncond_material_Dataset, + num=5000, + prompt='Produce a material that has any bulk modulus or composition', + reader_cfg=uncond_material_reader_cfg, + infer_cfg=uncond_material_infer_cfg, + eval_cfg=uncond_material_eval_cfg, + ) +] +mini_uncond_material_datasets = [ + dict( + abbr='SciReasoner-unconditional_material_generation-mini', + type=Uncond_material_Dataset, + num=150, + prompt='Produce a material that has any bulk modulus or composition', + reader_cfg=uncond_material_reader_cfg, + infer_cfg=uncond_material_infer_cfg, + eval_cfg=uncond_material_eval_cfg, + ) +] diff --git a/opencompass/datasets/SciReasoner/GUE.py b/opencompass/datasets/SciReasoner/GUE.py new file mode 100644 index 000000000..7b8bc5320 --- /dev/null +++ b/opencompass/datasets/SciReasoner/GUE.py @@ -0,0 +1,210 @@ +# flake8: noqa + +import json +import os +import re +from typing import Union + +from datasets import Dataset, DatasetDict +from huggingface_hub import hf_hub_download +from sklearn.metrics import matthews_corrcoef + +from opencompass.datasets.base import BaseDataset +from opencompass.openicl import BaseEvaluator +from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS +from opencompass.utils import get_data_path + + +@LOAD_DATASET.register_module() +class GUE_Dataset(BaseDataset): + + @staticmethod + def load(path, task, mini_set=False): + + # if (hf_hub is True): + # repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1] + # train_path = train_path.split(repo_id + '/')[1] + # test_path = test_path.split(repo_id + '/')[1] + # + # train_path = hf_hub_download(repo_id, + # train_path, + # repo_type='dataset') + # test_path = hf_hub_download(repo_id, + # test_path, + # repo_type='dataset') + + path = get_data_path(path) + train_path = os.path.join(path, f'{task}/dev/data.json') + test_path = os.path.join(path, f'{task}/test/data.json') + + with open(train_path, 'r', encoding='utf-8') as f: + train_data = json.load(f) + with open(test_path, 'r', encoding='utf-8') as f: + test_data = json.load(f) + + def augment_output(data): + for item in data: + label = item.get('meta_data', {}).get('label', '') + item['output'] += f' The prediction result is {label}.' 
+            return data
+
+        train_data = augment_output(train_data[:5])
+        test_data = augment_output(test_data)
+        if mini_set:
+            import random
+            random.seed(1024)
+            test_data = random.sample(test_data, 150)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+def remove_think_tags(text: str) -> str:
+    if '<think>' not in text:
+        return text
+    if '</think>' not in text:
+        return ''
+    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def GUE_postprocessor(text: Union[str, None]) -> str:
+    if not isinstance(text, str):
+        return ''
+
+    text = text.strip()
+    text = remove_think_tags(text)
+
+    if text == '':
+        return ''
+
+    match = re.search(r'\bThe prediction result is\s+(positive|negative)\b',
+                      text, re.IGNORECASE)
+    if match:
+        return match.group(1).lower()
+
+    positive_patterns = [
+        r'\bpositive\b',
+        r'\bpositively\b',
+        r'\bpresence\b',
+        r'\bdetected\b',
+        r'\bidentified\b',
+        r'\bidentifiable\b',
+        r'\bfound\b',
+        r'\byes\b',
+        r'\blocated\b',
+        r'\bdetectable\b',
+        r'\bobservable\b',
+        r'\bevident\b',
+        r'\babsolutely\b',
+        r'\baffirmative\b',
+        r'\bcan\b',
+        r'\baffirm\b',
+        r'\bconfirm\b',
+        r'\bconfirms\b',
+        r'\breveals\b',
+        r'\bexistence\b',
+        r'\bcertainly\b',
+        r'\bconsistent\b',
+        r'\brecognizable\b',
+        r'\bshows core\b',
+        r'\bshows promoter\b',
+        r'\bshows characteristic\b',
+        r'\bevidenced by\b',
+        r'\bseeing characteristic patterns\b',
+        r'\bincludes\b',
+        r'\bcontains sequences\b',
+        r'\bexhibits clear\b',
+        r'\bcontains transcription\b',
+        r'\bexhibits sequences\b',
+        r'\bclearly contains\b',
+        r'\brecognized\b',
+        r'\bexhibits features\b',
+        r'\bcontains regulatory\b',
+        r'\bshows clear\b',
+        r'\bdisplays\b',
+        r'\bdefinitely has\b',
+        r'\bexhibits patterns\b',
+        r'\bclear evidence\b',
+        r'\bcontains a\b',
+        r'\byep\b',
+        r'\bcontains sites\b',
+        r'\bshows sequences\b',
+    ]
+
+    negative_patterns = [
+        r'\bnegative\b',
+        r'\bno\b',
+        r'\babsence\b',
+        r'\bnot\b',
+        r'\bcannot\b',
+        r'\bfails\b',
+        r'\babsent\b',
+        r'\blacks\b',
+    ]
+
+    for pattern in negative_patterns:
+        if re.search(pattern, text, re.IGNORECASE):
+            return 'negative'
+
+    for pattern in positive_patterns:
+        if re.search(pattern, text, re.IGNORECASE):
+            return 'positive'
+
+    return ''
+
+
+class GUE_Evaluator(BaseEvaluator):
+
+    def score(self, predictions, references):
+
+        def normalize(label):
+            label = label.strip().lower()
+            if label == 'positive':
+                return 1
+            elif label == 'negative':
+                return 0
+            else:
+                return None
+
+        total_count = len(predictions)
+
+        if isinstance(predictions[0], list):
+            predictions = [p[0] for p in predictions]
+
+        pred_bin_all = [
+            1 if p.strip().lower() == 'positive' else 0 for p in predictions
+        ]
+        ref_bin_all = [
+            1 if r.strip().lower() == 'positive' else 0 for r in references
+        ]
+        mcc_all = matthews_corrcoef(ref_bin_all, pred_bin_all)
+
+        filtered_pred = []
+        filtered_ref = []
+        skipped = 0
+
+        for p, r in zip(predictions, references):
+            p_norm = normalize(p)
+            r_norm = normalize(r)
+            if p_norm is None or r_norm is None:
+                skipped += 1
+                continue
+            filtered_pred.append(p_norm)
+            filtered_ref.append(r_norm)
+
+        if filtered_pred:
+            mcc_filtered = matthews_corrcoef(filtered_ref, filtered_pred)
+        else:
+            mcc_filtered = 0.0
+
+        return {
+            'matthews_correlation_all': mcc_all * 100,
+            'matthews_correlation_filtered': mcc_filtered * 100,
+            'non_pos_neg_count': skipped,
+            'total_count': total_count
+        }
diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/__init__.py b/opencompass/datasets/SciReasoner/LLM4Chem/__init__.py
new file mode 100644
index 000000000..344524879
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Chem/__init__.py
@@ -0,0 +1,12 @@
+from .config import TASK_TAGS as LLM4Chem_TASK_TAGS  # noqa: F401, F403
+from .config import TASKS as LLM4Chem_TASKS  # noqa: F401, F403
+from .config import \
+    TASKS_GENERATION_SETTINGS as \
+    LLM4Chem_TASKS_GENERATION_SETTINGS  # noqa: F401, F403
+from .evaluator import LLM4Chem_Evaluator  # noqa: F401
+from .evaluator import LLM4Chem_postprocess  # noqa: F401
+from .evaluator import LLM4ChemDataset  # noqa: F401, F403
+from .retrosynthesis_evaluator import \
+    Retrosynthesis_postprocess  # noqa: F401, F403
+from .retrosynthesis_evaluator import \
+    RetrosynthesisEvaluator  # noqa: F401, F403
diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/config.py b/opencompass/datasets/SciReasoner/LLM4Chem/config.py
new file mode 100644
index 000000000..36d83cece
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Chem/config.py
@@ -0,0 +1,166 @@
+TASKS = (
+    'forward_synthesis',
+    'retrosynthesis',
+    'molecule_captioning',
+    'molecule_generation',
+    'name_conversion-i2f',
+    'name_conversion-i2s',
+    'name_conversion-s2f',
+    'name_conversion-s2i',
+    'property_prediction-esol',
+    'property_prediction-lipo',
+    'property_prediction-bbbp',
+    'property_prediction-clintox',
+    'property_prediction-hiv',
+    'property_prediction-sider',
+)
+
+DEFAULT_MAX_INPUT_TOKENS = 512
+DEFAULT_MAX_NEW_TOKENS = 1024
+
+TASKS_GENERATION_SETTINGS = {
+    'forward_synthesis': {
+        'generation_kargs': {
+            'num_return_sequences': 5,
+            'num_beams': 8
+        },
+    },
+    'retrosynthesis': {
+        'max_new_tokens': 960,
+        'generation_kargs': {
+            'num_return_sequences': 10,
+            'num_beams': 13
+        },
+    },
+    'molecule_captioning': {
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4
+        },
+    },
+    'molecule_generation': {
+        'generation_kargs': {
+            'num_return_sequences': 5,
+            'num_beams': 8
+        },
+    },
+    'name_conversion-i2f': {
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 3,
+            'num_beams': 6
+        },
+    },
+    'name_conversion-i2s': {
+        'generation_kargs': {
+            'num_return_sequences': 5,
+            'num_beams': 8
+        },
+    },
+    'name_conversion-s2f': {
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 3,
+            'num_beams': 6
+        },
+    },
+    'name_conversion-s2i': {
+        'generation_kargs': {
+            'num_return_sequences': 5,
+            'num_beams': 8
+        },
+    },
+    'property_prediction-esol': {
+        'batch_size': 16,
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4,
+        },
+    },
+    'property_prediction-lipo': {
+        'batch_size': 16,
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4,
+        },
+    },
+    'property_prediction-bbbp': {
+        'batch_size': 16,
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4,
+        },
+    },
+    'property_prediction-clintox': {
+        'batch_size': 16,
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4,
+        },
+    },
+    'property_prediction-hiv': {
+        'batch_size': 16,
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4,
+        },
+    },
+    'property_prediction-sider': {
+        'batch_size': 16,
+        'max_new_tokens': 20,
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 4,
+        },
+    },
+}
+
+TASK_TAGS = {
+    'forward_synthesis': ('<SMILES>', '</SMILES>'),
+    'retrosynthesis': ('<SMILES>', '</SMILES>'),
+    'molecule_generation': ('<SMILES>', '</SMILES>'),
+    'molecule_captioning': (None, None),
+    'name_conversion-i2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
+    'name_conversion-i2s': ('<SMILES>', '</SMILES>'),
+    'name_conversion-s2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
+    'name_conversion-s2i': ('<IUPAC>', '</IUPAC>'),
+    'property_prediction-esol': ('<NUMBER>', '</NUMBER>'),
+    'property_prediction-lipo': ('<NUMBER>', '</NUMBER>'),
+    'property_prediction-bbbp': ('<BOOLEAN>', '</BOOLEAN>'),
+    'property_prediction-clintox': ('<BOOLEAN>', '</BOOLEAN>'),
+    'property_prediction-hiv': ('<BOOLEAN>', '</BOOLEAN>'),
+    'property_prediction-sider': ('<BOOLEAN>', '</BOOLEAN>'),
+}
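+
+# Worked example (hypothetical model output, assuming the LlaSMol-style tags
+# above): the evaluator extracts the span between TASK_TAGS[task][0] and
+# TASK_TAGS[task][1], so for task 'forward_synthesis':
+#
+#     raw = 'The product is <SMILES>CCO</SMILES>, formed by reduction.'
+#     extract_answer_part([raw], '<SMILES>', '</SMILES>')  # -> ['CCO']
+#
+# For 'molecule_captioning' both tags are None and the whole stripped
+# generation is treated as the answer.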
+
+# These tasks output SMILES, where there may be semicolons
+# that separate different parts. To facilitate evaluation,
+# each semicolon is replaced by a dot.
+TASKS_WITH_SEMICOLON_REPLACE = (
+    'forward_synthesis',
+    'retrosynthesis',
+    'molecule_generation',
+    'name_conversion-i2s',
+)
+
+# For these tasks, one input might have multiple gold answers,
+# so the gold answers should be read directly from the dataset
+# instead of taking the gold field of each sample as-is.
+TASKS_WITH_READING_GOLD_FROM_DATASET = ('forward_synthesis', 'retrosynthesis',
+                                        'molecule_generation',
+                                        'molecule_captioning',
+                                        'name_conversion-i2f',
+                                        'name_conversion-i2s',
+                                        'name_conversion-s2f',
+                                        'name_conversion-s2i')
+
+BASE_MODELS = {
+    'osunlp/LlaSMol-Mistral-7B': 'mistralai/Mistral-7B-v0.1',
+    'osunlp/LlaSMol-Galactica-6.7B': 'facebook/galactica-6.7b',
+    'osunlp/LlaSMol-Llama2-7B': 'meta-llama/Llama-2-7b-hf',
+    'osunlp/LlaSMol-CodeLlama-7B': 'codellama/CodeLlama-7b-hf',
+}
diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/evaluator.py b/opencompass/datasets/SciReasoner/LLM4Chem/evaluator.py
new file mode 100644
index 000000000..3262a8a42
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Chem/evaluator.py
@@ -0,0 +1,228 @@
+# flake8: noqa
+# NC-I2S NC-S2I task
+# https://github.com/OSU-NLP-Group/LLM4Chem
+
+import json
+import os
+import re
+
+from datasets import Dataset, DatasetDict
+from huggingface_hub import hf_hub_download
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+from opencompass.utils import get_data_path
+
+from .config import TASK_TAGS, TASKS_WITH_SEMICOLON_REPLACE
+from .utils.metrics import (calculate_boolean_metrics,
+                            calculate_formula_metrics,
+                            calculate_number_metrics, calculate_smiles_metrics,
+                            calculate_text_metrics)
+
+
+@LOAD_DATASET.register_module()
+class LLM4ChemDataset(BaseDataset):
+
+    @staticmethod
+    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
+
+        # if (hf_hub is True):
+        #     # load from huggingface hub
+        #     train_data = []
+        #     repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1]
+        #     train_path = train_path.split(repo_id + '/')[1]
+        #     test_path = test_path.split(repo_id + '/')[1]
+        #
+        #     train_path = hf_hub_download(repo_id,
+        #                                  train_path,
+        #                                  repo_type='dataset')
+        #     test_path = hf_hub_download(repo_id,
+        #                                 test_path,
+        #                                 repo_type='dataset')
+
+        path = get_data_path(path)
+        train_path = os.path.join(path, f'{task}/dev/data.json')
+        test_path = os.path.join(path, f'{task}/test/data.json')
+
+        with open(train_path, 'r', encoding='utf-8') as f:
+            train_data = json.load(f)
+        with open(test_path, 'r', encoding='utf-8') as f:
+            test_data = json.load(f)
+
+        # Keep only the first 5 training samples as in-context example candidates
+        train_data = train_data[:5]
+
+        if max_cut != -1:
+            test_data = test_data[:max_cut]
+        if mini_set:
+            import random
+            random.seed(1024)
+            test_data = random.sample(test_data, 50)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+def extract_answer_part(outputs, left_tag, right_tag, mode='tag'):
+    assert mode in ('tag', 'direct')
+
+    assert isinstance(outputs, list)
+    answers = []
+    for text in outputs:
+        if mode == 'direct' or (left_tag is None and right_tag is None):
+            # strip any leftover special tokens from the raw generation
+            text = text.replace('<s>', '').replace('</s>', '').strip()
+            answers.append(text.strip())
+            continue
+
+        left_tag_pos = text.find(left_tag)
+        if left_tag_pos == -1:
+            answers.append('')
+            continue
+        right_tag_pos = text.find(right_tag)
+        if right_tag_pos == -1:
+            answers.append('')
+            continue
+        text = text[left_tag_pos + len(left_tag):right_tag_pos].strip()
+        answers.append(text)
+    return answers
+
+
+@TEXT_POSTPROCESSORS.register_module('LLM4Chem_postprocess')
+def LLM4Chem_postprocess(text, task, *args, **kwargs):
+    # Drop any <think>...</think> reasoning spans
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    replace_semicolon = task in TASKS_WITH_SEMICOLON_REPLACE
+    pred = extract_answer_part([text], *(TASK_TAGS[task]), mode='tag')[0]
+    # tasks in TASKS_WITH_SEMICOLON_REPLACE need each semicolon
+    # replaced with a period
+    if replace_semicolon:
+        pred = pred.replace(';', '.')
+    # no matched tag
+    if pred == '':
+        tag = TASK_TAGS[task][0]
+
+        if tag == '<BOOLEAN>':
+            # Find the last yes/true/no/false in the text, case-insensitively
+            ans = re.findall(r'\b(?:yes|true|no|false)\b', text, re.IGNORECASE)
+            if ans:
+                # if the last hit is yes/true
+                if ans[-1].lower() in ('yes', 'true'):
+                    return 'Yes'
+                else:
+                    return 'No'
+            else:
+                return ''
+
+        if tag == '<NUMBER>':
+            # Find the last number in the text,
+            # after removing any <think>...</think> spans
+            text_2 = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+            ans = re.findall(r'-?\d*\.\d+|-?\d+', text_2)
+            if ans:
+                return ans[-1]
+            else:
+                return ''
+
+        if tag == '<MOLFORMULA>':
+            # Find the last chemical formula in the text
+            ans = re.findall(
+                r'[\[\(]?[A-Z][a-z]?\d*(?:\([A-Za-z0-9]+\)\d*)?[\]\)]?'
+                r'(?:[A-Z][a-z]?\d*|\([^\)]+\)\d*|\[[^\]]+\]\d*)'
+                r'*(?:[+-]{1,2})?(?:·\d*[A-Z][a-z]?\d*)*', text)
+            if ans:
+                return ans[-1]
+            else:
+                return ''
+
+    # print(f"prediction: {pred}")
+    return pred
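+
+
+# Worked examples of the tag fallbacks above (hypothetical outputs): when a
+# model omits the expected tags, the last untagged boolean word or number in
+# the text is taken as the prediction, e.g.
+#
+#     LLM4Chem_postprocess('The answer is yes.',
+#                          task='property_prediction-bbbp')  # -> 'Yes'
+#     LLM4Chem_postprocess('The logP value is about -0.77.',
+#                          task='property_prediction-lipo')  # -> '-0.77'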
+
+
+class LLM4Chem_Evaluator(BaseEvaluator):
+
+    def __init__(self, task, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.task = task
+
+    def score(self, predictions, references):
+        if len(predictions) != len(references):
+            return {
+                'error': 'predictions and references have different '
+                'length'
+            }
+        if not isinstance(predictions[0], list):
+            predictions = [[pred] for pred in predictions]
+        if not isinstance(references[0], list):
+            references = [[ref] for ref in references]
+
+        task = self.task
+        pred_list = predictions
+        gold_list = references
+
+        if task in ('property_prediction-esol', 'property_prediction-lipo',
+                    'property_prediction-bbbp', 'property_prediction-clintox',
+                    'property_prediction-hiv', 'property_prediction-sider'):
+            # keep only the first candidate, i.e. shape [length, 1]
+            pred_list = [[pred[0]] for pred in pred_list]
+
+        if task in ('forward_synthesis', 'molecule_generation',
+                    'name_conversion-i2s'):
+            r = calculate_smiles_metrics(pred_list, gold_list)
+        elif task in ('retrosynthesis', ):
+            r = calculate_smiles_metrics(pred_list,
+                                         gold_list,
+                                         metrics=('exact_match', 'fingerprint',
+                                                  'multiple_match'))
+        elif task in ('molecule_captioning', ):
+            r = calculate_text_metrics(
+                pred_list,
+                gold_list,
+                text_model='allenai/scibert_scivocab_uncased',
+                text_trunc_length=2048,
+            )
+        elif task in ('name_conversion-i2f', 'name_conversion-s2f'):
+            r = calculate_formula_metrics(pred_list,
+                                          gold_list,
+                                          metrics=('element_match', ))
+        elif task in ('name_conversion-s2i', ):
+            r = calculate_formula_metrics(pred_list,
+                                          gold_list,
+                                          metrics=('split_match', ))
+        elif task in ('property_prediction-esol', 'property_prediction-lipo'):
+            r = calculate_number_metrics(pred_list, gold_list)
+        elif task in ('property_prediction-bbbp',
+                      'property_prediction-clintox', 'property_prediction-hiv',
+                      'property_prediction-sider'):
+            r = calculate_boolean_metrics(pred_list, gold_list)
+        else:
+            raise ValueError(task)
+
+        if 'num_t1_exact_match' in r and 'num_all' in r:
+            # as a percentage, rounded to 2 decimal places
+            r['top1_exact_match'] = round(
+                r['num_t1_exact_match'] / r['num_all'] * 100, 2)
+        if 'num_t5_exact_match' in r and 'num_all' in r:
+            # as a percentage, rounded to 2 decimal places
+            r['top5_exact_match'] = round(
+                r['num_t5_exact_match'] / r['num_all'] * 100, 2)
+        if 'num_t1_ele_match' in r and 'num_all' in r:
+            # as a percentage, rounded to 2 decimal places
+            r['top1_ele_match'] = round(
+                r['num_t1_ele_match'] / r['num_all'] * 100, 2)
+        if 'num_correct' in r and 'num_all' in r:
+            r['accuracy'] = round(r['num_correct'] / r['num_all'] * 100, 2)
+        if 'num_t1_split_match' in r and 'num_all' in r:
+            # as a percentage, rounded to 2 decimal places
+            r['top1_split_match'] = round(
+                r['num_t1_split_match'] / r['num_all'] * 100, 2)
+        if 'num_t5_split_match' in r and 'num_all' in r:
+            # as a percentage, rounded to 2 decimal places
+            r['top5_split_match'] = round(
+                r['num_t5_split_match'] / r['num_all'] * 100, 2)
+
+        return r
diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/retrosynthesis_evaluator.py b/opencompass/datasets/SciReasoner/LLM4Chem/retrosynthesis_evaluator.py
new file mode 100644
index 000000000..1305c20c4
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Chem/retrosynthesis_evaluator.py
@@ -0,0 +1,449 @@
+# dataset: USPTO-50K
+# https://github.com/otori-bird/retrosynthesis
+# task: retrosynthesis prediction
+import multiprocessing
+import re
+from functools import partial
+from typing import Union
+
+try:
+    from rdkit import Chem, RDLogger
+except Exception:
+    Chem, RDLogger = None, None
+
+from tqdm import tqdm
+
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import TEXT_POSTPROCESSORS
+
+# Silence RDKit's verbose log output
+# lg = RDLogger.logger()
+# lg.setLevel(RDLogger.CRITICAL)
+
+# ----------------------------------------------------------------------
+# 1. Core functions reused from the original script.
+#    They are kept at the top of the file so the Evaluator can call them.
+# ----------------------------------------------------------------------
+
+
+def smi_tokenizer(smi):
+    """
+    Tokenizes a SMILES string using a regular expression.
+    Note: This function was in the original script but is not directly used
+    in the evaluation logic. It's included for completeness.
+    """
+    pattern = (r'(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)'
+               r'|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])')
+    regex = re.compile(pattern)
+    tokens = [token for token in regex.findall(smi)]
+    assert smi == ''.join(tokens), f'SMILES tokenization failed for: {smi}'
+    return ' '.join(tokens)
+
+
+def canonicalize_smiles_clear_map(smiles, synthon=False, return_max_frag=True):
+    """
+    Canonicalizes a SMILES string, clears atom map numbers, and optionally
+    returns the largest fragment.
+
+    Args:
+        smiles (str): The SMILES string to process.
+        synthon (bool): Whether to skip the sanitization step.
+        return_max_frag (bool): If True, returns a tuple of
+                                (full_smiles, max_frag_smiles).
+                                Otherwise, returns only the full SMILES.
+
+    Returns:
+        A tuple (str, str) or a single str depending on return_max_frag.
+    """
+    mol = Chem.MolFromSmiles(smiles, sanitize=not synthon)
+    if mol is not None:
+        # Clear atom map numbers
+        for atom in mol.GetAtoms():
+            if atom.HasProp('molAtomMapNumber'):
+                atom.ClearProp('molAtomMapNumber')
+        try:
+            smi = Chem.MolToSmiles(mol, isomericSmiles=True)
+        except Exception:
+            # Handle cases where MolToSmiles fails
+            if return_max_frag:
+                return '', ''
+            else:
+                return ''
+
+        if return_max_frag:
+            sub_smi_list = smi.split('.')
+            if len(sub_smi_list) > 1:
+                # Find the largest fragment
+                sub_mols = [(s, Chem.MolFromSmiles(s, sanitize=not synthon))
+                            for s in sub_smi_list]
+                sub_mol_sizes = [(smi, len(m.GetAtoms()))
+                                 for smi, m in sub_mols if m is not None]
+                if sub_mol_sizes:
+                    # Sort fragments by size and return the largest one
+                    max_frag_smi = sorted(sub_mol_sizes,
+                                          key=lambda x: x[1],
+                                          reverse=True)[0][0]
+                    # Recursively canonicalize the largest fragment
+                    return smi, canonicalize_smiles_clear_map(
+                        max_frag_smi, synthon=synthon, return_max_frag=False)
+                else:
+                    return smi, ''
+            else:
+                # If no fragments, the molecule is its own largest fragment
+                return smi, smi
+        else:
+            return smi
+    else:
+        # If the molecule is invalid from the start
+        if return_max_frag:
+            return '', ''
+        else:
+            return ''
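+
+
+# Worked example (requires RDKit; hypothetical molecule): atom-map numbers
+# are cleared and the largest fragment is returned next to the full
+# canonical SMILES, e.g.
+#
+#     canonicalize_smiles_clear_map('[CH3:1][CH2:2][OH:3].[Na+]')
+#     # -> ('CCO.[Na+]', 'CCO')
+#
+# Invalid SMILES yield ('', ''), which the ranking below counts as an
+# invalid beam.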
+ """ + rank = {} + highest_pos = {} + invalid_rates = [0] * beam_size + + if raw: + # No test augmentation, len(prediction_group) is 1 + assert len(prediction_group) == 1, 'Raw mode requires augmentation=1' + aug_predictions = prediction_group[0] + for k in range(len(aug_predictions)): + pred_tuple = aug_predictions[k] + if not pred_tuple or not pred_tuple[0]: + invalid_rates[k] += 1 + continue + # Use rank as score for raw mode, lower is better + rank[pred_tuple] = 1 / (score_alpha * k + 1) + else: + # With test augmentation + for aug_predictions in prediction_group: + valid_k = [] # Store valid (prediction_tuple, original_beam_index) + for k, pred_tuple in enumerate(aug_predictions): + if pred_tuple and pred_tuple[0]: + valid_k.append((pred_tuple, k)) + else: + invalid_rates[k] += 1 + + # Deduplicate predictions within this augmentation run + seen = set() + deduped_preds = [] + for pred_tuple, k in valid_k: + if pred_tuple not in seen: + seen.add(pred_tuple) + deduped_preds.append((pred_tuple, k)) + + # Update ranks and highest positions + for k, (pred_tuple, _) in enumerate(deduped_preds): + score = 1 / (score_alpha * k + 1) + rank[pred_tuple] = rank.get(pred_tuple, 0) + score + highest_pos[pred_tuple] = min( + k, highest_pos.get(pred_tuple, float('inf'))) + + # Combine scores for final ranking + # The -1e8 term heavily penalizes lower ranks, + # ensuring highest position is prioritized + final_ranked_list = [] + if not raw: + for key, score in rank.items(): + final_ranked_list.append((key, score + highest_pos[key] * -1e8)) + else: + for key, score in rank.items(): + final_ranked_list.append((key, score)) + + final_ranked_list.sort(key=lambda x: x[1], reverse=True) + return final_ranked_list[:n_best], invalid_rates + + +# ---------------------------------------------------------------------- +# 定义 Postprocessor (后处理器) +# ---------------------------------------------------------------------- + + +@TEXT_POSTPROCESSORS.register_module() +def Retrosynthesis_postprocess(text: Union[str, None]) -> str: + """ + 从模型的原始输出中提取SMILES字符串。 + + 此函数会查找并返回被 标签包裹的内容。 + """ + # 检查输入是否为字符串,如果不是则返回空字符串,以提高代码健壮性 + if not isinstance(text, str): + return '' + + # 删除 标签及其内容 + text = re.sub(r'.*?', '', text, flags=re.DOTALL) + + # 使用正则表达式搜索SMILES标签内的内容 + # re.search() 会查找字符串中首次出现该模式的位置 + # (.*?) 是一个非贪婪捕获组,用于捕获两个标签之间的所有字符 + # re.DOTALL 标志让 '.' 可以匹配包括换行符在内的任意字符 + matches = re.findall(r'(.*?)', text, re.DOTALL) + + if matches: + # 如果找到匹配项,group(1)会返回第一个捕获组的内容 + # .strip() 用于去除捕获内容前后可能存在的多余空格或换行 + return matches[-1].strip() + else: + # 如果没有找到匹配的模式,返回一个空字符串 + return '' + + +# ---------------------------------------------------------------------- +# 定义 Evaluator (评估器) - 这是修改的核心 +# ---------------------------------------------------------------------- + + +class RetrosynthesisEvaluator(BaseEvaluator): + """ + Evaluator for retrosynthesis models. It calculates Top-K accuracy and + Max-Fragment accuracy based on SMILES string comparisons. 
+ """ + + def __init__(self, + beam_size=10, + n_best=10, + augmentation=1, + score_alpha=1.0, + synthon=False, + process_number=None): + super().__init__() + self.beam_size = beam_size + self.n_best = n_best + self.augmentation = augmentation + self.score_alpha = score_alpha + self.synthon = synthon + self.process_number = process_number if process_number is not None \ + else multiprocessing.cpu_count() + print(f'Evaluator initialized with: beam_size={beam_size},' + f' n_best={n_best}, augmentation={augmentation},' + f' processes={self.process_number}') + + def score(self, predictions, references): + """ + Calculates retrosynthesis prediction accuracy. + + Args: + predictions (list): A flat list of predicted SMILES strings. + Shape: [data_size * augmentation * beam_size]. + references (list): A list of ground truth SMILES strings. + Shape: [data_size]. + + Returns: + dict: A dictionary containing evaluation metrics. + """ + # flat predictions -> 1D + print(f'len of predictions: {len(predictions)}') + print(f'predictions[0]: {predictions[0]}') + if isinstance(predictions, list): + # Ensure predictions are a flat list + if isinstance(predictions[0], list): + predictions = [x for y in predictions for x in y] + else: + pass + + # print(f"predictions = {predictions} \nreferences = {references}") + data_size = len(references) + expected_preds_len = data_size * self.augmentation * self.beam_size + if len(predictions) != expected_preds_len: + return { + 'error': + f'Length of predictions ({len(predictions)})' + f' does not match expected length ({expected_preds_len})' + } + + print('Canonicalizing predictions and references...') + # Create a partial function for multiprocessing + map_func = partial(canonicalize_smiles_clear_map, + synthon=self.synthon, + return_max_frag=True) + + with multiprocessing.Pool(self.process_number) as pool: + can_predictions = list( + tqdm(pool.imap(map_func, predictions), + total=len(predictions), + desc='Canonicalizing Predictions')) + can_references = list( + tqdm(pool.imap(map_func, references), + total=len(references), + desc='Canonicalizing References')) + + # Reshape the flat predictions list into a 3D list: + # data_size x augmentation x beam_size + predictions_reshaped = [[] for _ in range(data_size)] + for i in range(data_size): + for j in range(self.augmentation): + start_idx = (i * self.augmentation + j) * self.beam_size + end_idx = start_idx + self.beam_size + predictions_reshaped[i].append( + can_predictions[start_idx:end_idx]) + + # Initialize metric counters + accuracy = [0] * self.n_best + max_frag_accuracy = [0] * self.n_best + total_invalid_rates = [0] * self.beam_size + + print('Computing ranks and accuracy...') + is_raw_mode = (self.augmentation == 1) + + for i in tqdm(range(data_size), desc='Evaluating Samples'): + prediction_group = predictions_reshaped[i] + target_smi_tuple = can_references[i] + + # Skip evaluation for this sample if the ground truth is invalid + if not target_smi_tuple or not target_smi_tuple[0]: + continue + + ranked_results, invalid_rate = compute_rank( + prediction_group, + beam_size=self.beam_size, + n_best=self.n_best, + score_alpha=self.score_alpha, + raw=is_raw_mode) + + # Aggregate invalid rates + for j in range(len(invalid_rate)): + total_invalid_rates[j] += invalid_rate[j] + + # Check for full molecule match + found_match = False + for j, (pred_tuple, _) in enumerate(ranked_results): + if not found_match and pred_tuple[0] == target_smi_tuple[0]: + for k in range(j, self.n_best): + accuracy[k] += 1 + found_match = True 
# Ensure we only count the first match + + # Check for max fragment match + found_frag_match = False + for j, (pred_tuple, _) in enumerate(ranked_results): + # Ensure max fragment is not empty before comparing + if not found_frag_match and pred_tuple[1] and pred_tuple[ + 1] == target_smi_tuple[1]: + for k in range(j, self.n_best): + max_frag_accuracy[k] += 1 + found_frag_match = True + + # Calculate final results + results = {} + # Usually, Top-1, 3, 5, 10 are reported + for i in [k - 1 for k in [1, 3, 5, 10] if k <= self.n_best]: + k = i + 1 + results[f'Top-{k} Accuracy'] = accuracy[i] / data_size * 100 + results[f'Top-{k} MaxFrag Accuracy'] = max_frag_accuracy[ + i] / data_size * 100 + + # Report the invalid rate at the first beam position + if self.beam_size > 0: + total_predictions_at_beam1 = data_size * self.augmentation + results['Invalid SMILES Rate (at beam 1)'] = ( + total_invalid_rates[0] / total_predictions_at_beam1 * 100) \ + if total_predictions_at_beam1 > 0 else 0 + + return results + + +# Example Usage +if __name__ == '__main__': + # --- Mock Data Generation --- + # This simulates the kind of data the evaluator would receive. + + # Configuration + BEAM_SIZE = 5 + N_BEST = 5 + AUGMENTATION = 3 # Use > 1 to test augmentation logic + DATA_SIZE = 100 + + # Ground truth molecules (references) + mock_references = [ + 'CCO.CN', # Correct: CCO is largest fragment + 'c1ccccc1CC(=O)O', # Correct + 'INVALID_SMILES', # An invalid reference SMILES + 'CC(C)C(=O)N[C@@H](C)C(=O)O' # Chiral molecule + ] * (DATA_SIZE // 4) + + # Simulated model predictions (a flat list) + mock_predictions = [] + for i in range(DATA_SIZE): + target = mock_references[i] + for _ in range(AUGMENTATION): + # For each augmentation, create a beam of predictions + beam = [] + # Make the first beam prediction correct for 20% of cases + if i % 5 == 0: + beam.append(target) + else: + beam.append('CC(C)=O') # A common incorrect prediction + + # Add some other variations and invalid SMILES + beam.append('c1cnccc1') + beam.append('completely_invalid') # Invalid + # Add a prediction that only matches the largest fragment + beam.append('CCO') + # Fill the rest of the beam + beam.extend(['C'] * (BEAM_SIZE - len(beam))) + + mock_predictions.extend(beam) + + print(f'Generated {len(mock_predictions)} ' + f'predictions for {len(mock_references)} references.') + + # --- Evaluation --- + evaluator = RetrosynthesisEvaluator( + beam_size=BEAM_SIZE, + n_best=N_BEST, + augmentation=AUGMENTATION, + process_number=4 # Use 4 cores for the example + ) + + results = evaluator.score(mock_predictions, mock_references) + + # --- Print Results --- + print('\n--- Evaluation Results ---') + for key, value in results.items(): + print(f'{key}: {value:.2f}%') + print('--------------------------\n') + + # --- Test RAW mode (no augmentation) --- + print('Testing RAW mode (augmentation=1)...') + evaluator_raw = RetrosynthesisEvaluator( + beam_size=BEAM_SIZE, + n_best=N_BEST, + augmentation=1, # RAW mode + process_number=4) + # Select only the first "augmentation" set of predictions + mock_predictions_raw = [] + for i in range(DATA_SIZE): + start_idx = i * AUGMENTATION * BEAM_SIZE + end_idx = start_idx + BEAM_SIZE + mock_predictions_raw.extend(mock_predictions[start_idx:end_idx]) + + results_raw = evaluator_raw.score(mock_predictions_raw, mock_references) + + print('\n--- RAW Mode Evaluation Results ---') + for key, value in results_raw.items(): + print(f'{key}: {value:.2f}%') + print('---------------------------------\n') diff --git 
a/opencompass/datasets/SciReasoner/LLM4Chem/utils/__input__.py b/opencompass/datasets/SciReasoner/LLM4Chem/utils/__input__.py new file mode 100644 index 000000000..40f570c65 --- /dev/null +++ b/opencompass/datasets/SciReasoner/LLM4Chem/utils/__input__.py @@ -0,0 +1 @@ +import smiles_canonicalization # noqa: F401, F403 diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/utils/chat_generation.py b/opencompass/datasets/SciReasoner/LLM4Chem/utils/chat_generation.py new file mode 100644 index 000000000..d64618ff0 --- /dev/null +++ b/opencompass/datasets/SciReasoner/LLM4Chem/utils/chat_generation.py @@ -0,0 +1,12 @@ +def generate_chat(input_text, output_text=None, prefix_chat=None): + chat = [ + { + 'role': 'user', + 'content': input_text + }, + ] + if output_text is not None: + chat.append({'role': 'assistant', 'content': output_text}) + if prefix_chat is not None: + chat = prefix_chat + chat + return chat diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/utils/core_tagger.py b/opencompass/datasets/SciReasoner/LLM4Chem/utils/core_tagger.py new file mode 100644 index 000000000..2ed0bd4ca --- /dev/null +++ b/opencompass/datasets/SciReasoner/LLM4Chem/utils/core_tagger.py @@ -0,0 +1,195 @@ +def find_sub_sequence(whole, sub): + assert isinstance(whole, list) + assert isinstance(sub, list) + len_whole = len(whole) + len_sub = len(sub) + assert len_whole > 0 + assert len_sub > 0 + + s = 0 + while True: + s_whole = whole[s:] + try: + k_pos = s_whole.index(sub[0]) + except ValueError: + return -1 + + fail = False + for i in range(1, len_sub): + try: + if s_whole[k_pos + i] != sub[i]: + fail = True + break + except IndexError: + return -1 + if fail: + s = s + k_pos + 1 + continue + else: + return s + k_pos + + +class CoreTagger(object): + + def __init__(self, + tokenizer, + core_tags_as_special_tokens=False, + include_tags=True): + self.tokenizer = tokenizer + if core_tags_as_special_tokens: + raise NotImplementedError + self.core_tags_as_special_tokens = core_tags_as_special_tokens + if not include_tags: + raise NotImplementedError + self.include_tags = include_tags + + self.left_tag_to_id = {} + self.right_tag_to_id = {} + + def generate_mask(self, token_ids, output_begin, sample): + mask = [0] * len(token_ids) + left_tag, right_tag = sample['output_core_tag_left'], sample[ + 'output_core_tag_right'] + if left_tag not in self.left_tag_to_id: + if left_tag is None: + left_token_ids = None + else: + left_token_ids = self.tokenizer( + left_tag, + add_special_tokens=False, + return_attention_mask=False)['input_ids'] + self.left_tag_to_id[left_tag] = left_token_ids + else: + left_token_ids = self.left_tag_to_id[left_tag] + if right_tag not in self.right_tag_to_id: + if right_tag is None: + right_token_ids = None + else: + right_token_ids = self.tokenizer( + right_tag, + add_special_tokens=False, + return_attention_mask=False)['input_ids'] + self.right_tag_to_id[right_tag] = right_token_ids + else: + right_token_ids = self.right_tag_to_id[right_tag] + + output_token_ids = token_ids[output_begin:] + if left_token_ids is None: + left_position = output_begin + elif len(output_token_ids) == 0: + left_position = None + else: + left_position = find_sub_sequence(output_token_ids, + left_token_ids) + output_begin + if left_position == -1: + left_position = None + + if left_position is None: + return mask + + if right_token_ids is None: + right_position = len(token_ids) + if token_ids[-1] == self.tokenizer.eos_token_id: + right_position -= 1 + else: + right_position = find_sub_sequence(output_token_ids, + 
right_token_ids) + output_begin
+            if right_position == -1:
+                right_position = len(token_ids)
+                if token_ids[-1] == self.tokenizer.eos_token_id:
+                    right_position -= 1
+            else:
+                right_position = min(right_position + len(right_token_ids),
+                                     len(token_ids))
+
+        for idx in range(left_position, right_position):
+            mask[idx] = 1
+
+        return mask
+
+
+class CoreTaggerGeneral(object):
+
+    def __init__(self,
+                 tokenizer,
+                 core_tags_as_special_tokens=False,
+                 include_tags=True):
+        self.tokenizer = tokenizer
+        if core_tags_as_special_tokens:
+            raise NotImplementedError
+        self.core_tags_as_special_tokens = core_tags_as_special_tokens
+        if not include_tags:
+            raise NotImplementedError
+        self.include_tags = include_tags
+
+        self.left_tag_to_id = {}
+        self.right_tag_to_id = {}
+
+    def generate_mask(self, token_ids, prompt_mask, sample):
+        mask = [0] * len(token_ids)
+        left_tag, right_tag = sample['output_core_tag_left'], sample[
+            'output_core_tag_right']
+        if left_tag not in self.left_tag_to_id:
+            if left_tag is None:
+                left_token_ids = None
+            else:
+                left_token_ids = self.tokenizer(
+                    left_tag,
+                    add_special_tokens=False,
+                    return_attention_mask=False)['input_ids']
+            self.left_tag_to_id[left_tag] = left_token_ids
+        else:
+            left_token_ids = self.left_tag_to_id[left_tag]
+        if right_tag not in self.right_tag_to_id:
+            if right_tag is None:
+                right_token_ids = None
+            else:
+                right_token_ids = self.tokenizer(
+                    right_tag,
+                    add_special_tokens=False,
+                    return_attention_mask=False)['input_ids']
+            self.right_tag_to_id[right_tag] = right_token_ids
+        else:
+            right_token_ids = self.right_tag_to_id[right_tag]
+
+        cur_ = 0
+        for idx in range(len(token_ids)):
+            if prompt_mask[idx] == 1 or token_ids[
+                    idx] == self.tokenizer.bos_token_id:
+                cur_ = 0
+                continue
+
+            if left_token_ids is None:
+                match_left = True
+            else:
+                match_left = True
+                try:
+                    for offset in range(len(left_token_ids)):
+                        if token_ids[idx + offset] != left_token_ids[offset]:
+                            match_left = False
+                            break
+                except IndexError:
+                    match_left = False
+
+            if match_left:
+                cur_ = 1
+
+            mask[idx] = cur_
+
+            if right_token_ids is None:
+                continue
+
+            match_right = True
+            try:
+                for offset in range(len(right_token_ids)):
+                    if token_ids[idx - len(right_token_ids) +
+                                 offset] != right_token_ids[offset]:
+                        match_right = False
+                        break
+            except IndexError:
+                match_right = False
+
+            if match_right:
+                cur_ = 0
+
+        return mask
diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/utils/general_prompter.py b/opencompass/datasets/SciReasoner/LLM4Chem/utils/general_prompter.py
new file mode 100644
index 000000000..7f22b32c3
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Chem/utils/general_prompter.py
@@ -0,0 +1,35 @@
+def get_chat_content(conversation, tokenize=False):
+    if tokenize:
+        raise NotImplementedError
+    available_roles = ('user', 'assistant')
+    content = ''
+    for idx, item in enumerate(conversation):
+        role = item['role']
+        assert role in available_roles, role
+        if idx % 2 == 0:
+            assert role == 'user'
+            content += '<s>'
+            item_content = '[INST] %s [/INST]' % item['content']
+            content += item_content
+        else:
+            assert role == 'assistant'
+            item_content = ' %s</s>' % item['content']
+            content += item_content
+    return content
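+
+
+# Worked example (assumes the Llama/Mistral-style [INST] template sketched
+# above):
+#
+#     get_chat_content([
+#         {'role': 'user', 'content': 'Hi'},
+#         {'role': 'assistant', 'content': 'Hello'},
+#     ])
+#     # -> '<s>[INST] Hi [/INST] Hello</s>'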
+
+
+class GeneralPrompter(object):
+
+    def __init__(self, apply_chat_template_func, response_split='[/INST]'):
+        self.apply_chat_template_func = apply_chat_template_func
+        self.response_split = response_split
+
+    def generate_prompt(self, chat, tokenize=False, *args, **kargs) -> str:
+        res = self.apply_chat_template_func(chat,
+                                            tokenize=tokenize,
+                                            *args,
+                                            **kargs)
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(self.response_split)[-1].strip()
diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/utils/metrics.py b/opencompass/datasets/SciReasoner/LLM4Chem/utils/metrics.py
new file mode 100644
index 000000000..fc306c25d
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Chem/utils/metrics.py
@@ -0,0 +1,685 @@
+# flake8: noqa
+
+import re
+from collections import defaultdict
+
+import numpy as np
+from nltk.translate.bleu_score import corpus_bleu
+from nltk.translate.meteor_score import meteor_score
+
+try:
+    from rdkit import Chem, DataStructs, RDLogger
+    from rdkit.Chem import AllChem, MACCSkeys
+except Exception:
+    Chem, DataStructs, RDLogger, AllChem, MACCSkeys = None, None, None, None, None
+
+from rouge_score import rouge_scorer
+from sklearn.metrics import (f1_score, matthews_corrcoef, precision_score,
+                             recall_score, roc_auc_score)
+from tqdm.auto import tqdm
+from transformers import BertTokenizerFast
+
+from .smiles_canonicalization import (canonicalize_molecule_smiles,
+                                      get_molecule_id)
+
+# RDLogger.DisableLog('rdApp.*')
+
+
+def convert_smiles_list_into_mol_list(smiles_list,
+                                      raise_error_when_error=False):
+    mol_list = []
+    no_answer_labels = []
+    invalid_labels = []
+    for smiles in smiles_list:
+        if smiles == '':
+            mol = 'NA'
+            no_answer_labels.append(True)
+            if raise_error_when_error:
+                raise ValueError('SMILES is empty.')
+        else:
+            mol = Chem.MolFromSmiles(smiles)
+            if mol is None:
+                mol = 'INVALID'
+                invalid_labels.append(True)
+                if raise_error_when_error:
+                    raise ValueError('SMILES is not valid: %s' % smiles)
+        mol_list.append(mol)
+
+    no_answer_labels = np.array(no_answer_labels)
+    invalid_labels = np.array(invalid_labels)
+
+    return mol_list, no_answer_labels, invalid_labels
+
+
+def judge_exact_match(pred_can_smiles_list, gold_can_smiles_list):
+    assert len(pred_can_smiles_list) == len(gold_can_smiles_list)
+    exact_match_labels = []
+    for pred_smiles, gold_smiles_list in zip(pred_can_smiles_list,
+                                             gold_can_smiles_list):
+        if pred_smiles is None or pred_smiles.strip() == '':
+            exact_match_labels.append(False)
+            continue
+        pred_smiles_inchi = get_molecule_id(pred_smiles)
+        sample_exact_match = False
+        for gold_smiles in gold_smiles_list:
+            assert gold_smiles is not None
+            gold_smiles_inchi = get_molecule_id(gold_smiles)
+            if pred_smiles_inchi == gold_smiles_inchi:
+                sample_exact_match = True
+                break
+        exact_match_labels.append(sample_exact_match)
+    return np.array(exact_match_labels)
+
+
+def calculate_fingerprint_similarity(pred_mol_list,
+                                     gold_mols_list,
+                                     morgan_r=2):
+    assert len(pred_mol_list) == len(gold_mols_list)
+    MACCS_sims = []
+    morgan_sims = []
+    RDK_sims = []
+    for pred_mol, gold_mol_list in zip(pred_mol_list, gold_mols_list):
+        if pred_mol is None or type(pred_mol) == str:
+            raise ValueError(type(pred_mol))
+        tmp_MACCS, tmp_RDK, tmp_morgan = 0, 0, 0
+        for gold_mol in gold_mol_list:
+            tmp_MACCS = max(
+                tmp_MACCS,
+                DataStructs.FingerprintSimilarity(
+                    MACCSkeys.GenMACCSKeys(gold_mol),
+                    MACCSkeys.GenMACCSKeys(pred_mol),
+                    metric=DataStructs.TanimotoSimilarity))
+            tmp_RDK = max(
+                tmp_RDK,
+                DataStructs.FingerprintSimilarity(
+                    Chem.RDKFingerprint(gold_mol),
+                    Chem.RDKFingerprint(pred_mol),
+                    metric=DataStructs.TanimotoSimilarity))
+            tmp_morgan = max(
+                tmp_morgan,
+                DataStructs.TanimotoSimilarity(
+                    AllChem.GetMorganFingerprint(gold_mol, morgan_r),
+                    AllChem.GetMorganFingerprint(pred_mol, morgan_r)))
+        MACCS_sims.append(tmp_MACCS)
RDK_sims.append(tmp_RDK) + morgan_sims.append(tmp_morgan) + maccs_sims_score = np.mean(MACCS_sims) + rdk_sims_score = np.mean(RDK_sims) + morgan_sims_score = np.mean(morgan_sims) + return maccs_sims_score, rdk_sims_score, morgan_sims_score + + +def judge_multiple_match(pred_can_smiles_list, golds_can_smiles_list): + assert len(pred_can_smiles_list) == len(golds_can_smiles_list) + subset_labels = [] + intersection_labels = [] + for pred_smiles, gold_smiles_list in zip(pred_can_smiles_list, + golds_can_smiles_list): + if pred_smiles is None: + subset_labels.append(False) + intersection_labels.append(False) + continue + + pred_ele_set = set() + for smiles in pred_smiles.split('.'): + pred_ele_set.add(get_molecule_id(smiles, remove_duplicate=False)) + + intersection_label = False + subset_label = False + for gold_smiles in gold_smiles_list: + assert gold_smiles is not None + gold_ele_set = set() + for smiles in gold_smiles.split('.'): + gold_ele_set.add( + get_molecule_id(smiles, remove_duplicate=False)) + + if len(pred_ele_set & gold_ele_set) > 0: + intersection_label = True + g_p = gold_ele_set - pred_ele_set + if len(g_p) >= 0 and len(pred_ele_set - gold_ele_set) == 0: + subset_label = True + break + intersection_labels.append(intersection_label) + subset_labels.append(subset_label) + + return intersection_labels, subset_labels + + +def calculate_smiles_metrics(preds_smiles_list, + golds_smiles_list, + metrics=('exact_match', 'fingerprint')): + num_all = len(preds_smiles_list) + assert num_all > 0 + assert num_all == len(golds_smiles_list) + k = len(preds_smiles_list[0]) + + dk_pred_smiles_list_dict = {} + dk_pred_no_answer_labels_dict = {} + dk_pred_invalid_labels_dict = {} + for dk in range(k): + dk_pred_smiles_list_dict[dk] = [] + dk_pred_no_answer_labels_dict[dk] = [] + dk_pred_invalid_labels_dict[dk] = [] + for pred_smiles_list in tqdm(preds_smiles_list): + if pred_smiles_list is None: + for dk in range(k): + dk_pred_no_answer_labels_dict[dk].append(True) + dk_pred_invalid_labels_dict[dk].append(False) + dk_pred_smiles_list_dict[dk].append(None) + continue + assert len(pred_smiles_list) == k + for dk, item in enumerate(pred_smiles_list): + # item = item.strip() + if item == '' or item is None: + item = None + dk_pred_no_answer_labels_dict[dk].append(True) + dk_pred_invalid_labels_dict[dk].append(False) + else: + dk_pred_no_answer_labels_dict[dk].append(False) + item = canonicalize_molecule_smiles(item) + if item is None: + dk_pred_invalid_labels_dict[dk].append(True) + else: + dk_pred_invalid_labels_dict[dk].append(False) + dk_pred_smiles_list_dict[dk].append(item) + + new_list = [] + for gold_smiles_list in tqdm(golds_smiles_list): + sample_gold_smiles_list = [] + for gold in gold_smiles_list: + item = gold.strip() + new_item = canonicalize_molecule_smiles( + item, return_none_for_error=False) + # if new_item is None: + # new_item = item #TODO + # assert new_item is not None, item + sample_gold_smiles_list.append(new_item) + new_list.append(sample_gold_smiles_list) + golds_smiles_list = new_list + + metric_results = {'num_all': num_all} + + tk_pred_no_answer_labels = np.array([True] * num_all) + tk_pred_invalid_labels = np.array([True] * num_all) + for dk in range(k): + dk_no_answer_labels = dk_pred_no_answer_labels_dict[dk] + dk_invalid_labels = dk_pred_invalid_labels_dict[dk] + tk_pred_no_answer_labels = tk_pred_no_answer_labels & \ + dk_no_answer_labels + tk_pred_invalid_labels = tk_pred_invalid_labels & dk_invalid_labels + metric_results['num_t%d_no_answer' % + (dk + 1)] = 
tk_pred_no_answer_labels.sum().item()
+        metric_results['num_t%d_invalid' %
+                       (dk + 1)] = tk_pred_invalid_labels.sum().item()
+
+    # d1_no_answer_labels = dk_pred_no_answer_labels_dict[0]
+    # # print(np.array(d1_no_answer_labels).sum().item())
+    # for label, item in zip(d1_no_answer_labels, preds_smiles_list):
+    #     if label:
+    #         print(item)
+
+    for metric in metrics:
+        if metric == 'exact_match':
+            tk_exact_match_labels = np.array([False] * num_all)
+            for dk in range(k):
+                dk_pred_smiles_list = dk_pred_smiles_list_dict[dk]
+                dk_exact_match_labels = judge_exact_match(
+                    dk_pred_smiles_list, golds_smiles_list)
+                tk_exact_match_labels = tk_exact_match_labels | \
+                    dk_exact_match_labels
+                metric_results['num_t%d_exact_match' %
+                               (dk + 1)] = tk_exact_match_labels.sum().item()
+        elif metric == 'fingerprint':
+            d1_pred_mol_list = []
+            gold_mols_list = []
+            for pred_smiles, gold_smiles_list, no_answer, invalid in zip(
+                    dk_pred_smiles_list_dict[0], golds_smiles_list,
+                    dk_pred_no_answer_labels_dict[0],
+                    dk_pred_invalid_labels_dict[0]):
+                if pred_smiles is None or pred_smiles.strip(
+                ) == '' or no_answer is True or invalid is True:
+                    continue
+                pred_mol = Chem.MolFromSmiles(pred_smiles)
+                if pred_mol is None:  # TODO
+                    continue
+                assert pred_mol is not None, pred_smiles
+                gold_mol_list = []
+                for gold_smiles in gold_smiles_list:
+                    gold_mol = Chem.MolFromSmiles(gold_smiles)
+                    # if gold_mol is None:
+                    #     continue  # TODO
+                    assert gold_mol is not None, gold_smiles
+                    gold_mol_list.append(gold_mol)
+                # if len(gold_mol_list) == 0:
+                #     continue  # TODO
+                d1_pred_mol_list.append(pred_mol)
+                gold_mols_list.append(gold_mol_list)
+            maccs_sims_score, rdk_sims_score, morgan_sims_score = \
+                calculate_fingerprint_similarity(
+                    d1_pred_mol_list, gold_mols_list)
+            metric_results['t1_maccs_fps'] = maccs_sims_score
+            metric_results['t1_rdk_fps'] = rdk_sims_score
+            metric_results['t1_morgan_fps'] = morgan_sims_score
+        elif metric == 'multiple_match':
+            tk_intersection_labels = np.array([False] * num_all)
+            tk_subset_labels = np.array([False] * num_all)
+            for dk in range(k):
+                dk_intersection_labels, dk_subset_labels = \
+                    judge_multiple_match(
+                        dk_pred_smiles_list_dict[dk], golds_smiles_list)
+                tk_intersection_labels = tk_intersection_labels | \
+                    dk_intersection_labels
+                tk_subset_labels = tk_subset_labels | dk_subset_labels
+                metric_results['num_t%d_subset' %
+                               (dk + 1)] = tk_subset_labels.sum().item()
+                metric_results['num_t%d_intersection' %
+                               (dk + 1)] = tk_intersection_labels.sum().item()
+        else:
+            raise ValueError(metric)
+
+    return metric_results
+
+
+def judge_string_exact_match(pred_string_list, golds_string_list):
+    exact_match_labels = []
+    for pred_string, gold_string_list in zip(pred_string_list,
+                                             golds_string_list):
+        exact_match = False
+        for gold_string in gold_string_list:
+            if pred_string == gold_string:
+                exact_match = True
+                break
+        exact_match_labels.append(exact_match)
+    return np.array(exact_match_labels)
+
+
+def judge_string_split_match(pred_string_list,
+                             golds_string_list,
+                             separator=';'):
+    exact_match_labels = []
+    for pred_string, gold_string_list in zip(pred_string_list,
+                                             golds_string_list):
+        pred_item = tuple(sorted(pred_string.split(separator)))
+        exact_match = False
+        for gold_string in gold_string_list:
+            gold_item = tuple(sorted(gold_string.split(separator)))
+            if pred_item == gold_item:
+                exact_match = True
+                break
+        exact_match_labels.append(exact_match)
+    return np.array(exact_match_labels)
+
+
+def parse_molecule(molecular_formula):
+    valid = re.match(r'([A-Za-z]\d*)+([\+\-]\d*)*$', molecular_formula)
+    if valid is None:
+        raise ValueError("Molecular formula \"%s\" is not valid." %
+                         molecular_formula)
+
+    stack = [defaultdict(int)]
+
+    def _parse_formula(formula, _stack):
+
+        # Set remainder equal to 'None'
+        r = None
+
+        # Regular expression matching for each of the three cases:
+        atom = re.match(r'([A-Z][a-z]?)(\d+)?', formula)
+        opening = re.match(r'[\(\[\{]', formula)
+        closing = re.match(r'[\)\]\}](\d+)?', formula)
+
+        # If atom is identified:
+        if atom:
+            r = formula[len(atom.group()):]
+            _stack[-1][atom.group(1)] += int(atom.group(2) or 1)
+
+        # If opening brackets encountered:
+        elif opening:
+            # this sets the remainder equal
+            # to everything after the opening brackets
+            r = formula[len(opening.group()):]
+            _stack.append(defaultdict(int))
+
+        # If closing brackets encountered:
+        elif closing:
+            r = formula[len(closing.group()):]
+            # this sets the remainder equal to
+            # everything after the closing brackets
+            for (k, v) in _stack.pop().items():
+                _stack[-1][k] += v * int(
+                    closing.group(1)
+                    # v times amount of molecule k,
+                    # depending on nesting
+                    or 1)
+
+        # If anything remains,
+        # process remainders recursively as nested formulas:
+        if r:
+            _parse_formula(r, _stack)
+
+        return dict(_stack[0])
+
+    result = _parse_formula(molecular_formula, stack)
+
+    charge = re.search(r'[\+\-]\d*', molecular_formula)
+    if charge is not None:
+        charge_str = charge.group()
+        charge_type = charge_str[0]
+        if len(charge_str) == 1:
+            charge_num = 1
+        else:
+            charge_num = int(charge_str[1:])
+        result[charge_type] = charge_num
+
+    return result
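+
+
+# Worked example: a formula is flattened into an element-count dict, so the
+# comparison below is order-insensitive ('C5H8' matches 'H8C5'), e.g.
+#
+#     parse_molecule('C5H8')   # -> {'C': 5, 'H': 8}
+#     parse_molecule('SO4-2')  # -> {'S': 1, 'O': 4, '-': 2}
+#
+# A trailing charge such as '-2' is stored under its sign character.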
+
+
+def count_element_match(pred_formula_list, golds_formula_list):
+    assert len(pred_formula_list) == len(golds_formula_list)
+    ele_match_labels = []
+    ele_invalid_labels = []
+    for pred_formula, gold_formula_list in zip(pred_formula_list,
+                                               golds_formula_list):
+        if pred_formula == '' or pred_formula is None:
+            ele_invalid_labels.append(False)
+            ele_match_labels.append(False)
+            continue
+        try:
+            pred_ele = parse_molecule(pred_formula)
+        except KeyboardInterrupt:
+            raise
+        except Exception:
+            # print(pred_formula)
+            # print('=====')
+            ele_invalid_labels.append(True)
+            ele_match_labels.append(False)
+            continue
+        ele_invalid_labels.append(False)
+        ele_match = False
+        for gold_formula in gold_formula_list:
+            gold_ele = parse_molecule(gold_formula)
+            if pred_ele == gold_ele:
+                ele_match = True
+                break
+        ele_match_labels.append(ele_match)
+    return ele_match_labels, ele_invalid_labels
+
+
+def calculate_formula_metrics(preds_formula_list,
+                              golds_formula_list,
+                              metrics=('element_match', )):
+    """
+    Calculate metrics for molecular formulas.
+    By default we use element_match (equivalent to the exact_match metric
+    used in our paper), which compares atom counts and ignores their order.
+    For example, C5H8 == H8C5.
+ """ + num_all = len(preds_formula_list) + assert len(preds_formula_list) == len(golds_formula_list) + try: + k = len(preds_formula_list[0]) + except IndexError: + print(preds_formula_list) + raise + dk_pred_formula_list_dict = dict() + for dk in range(k): + dk_pred_formula_list_dict[dk] = [] + for sample_formula_list in preds_formula_list: + if sample_formula_list is None: + for dk in range(k): + dk_pred_formula_list_dict[dk].append('') + continue + assert len(sample_formula_list) == k + for dk in range(k): + item = sample_formula_list[dk] + dk_pred_formula_list_dict[dk].append(item) + golds_formula_list = [[small_item.strip() for small_item in item] + for item in golds_formula_list] + new_golds_formula_list = [] + for item in golds_formula_list: + new_item = [] + for small_item in item: + small_item = small_item.strip() + assert small_item != '' + new_item.append(small_item) + new_golds_formula_list.append(new_item) + golds_formula_list = new_golds_formula_list + + metric_results = {'num_all': num_all} + + tk_no_answer_labels = np.array([True] * num_all) + for dk in range(k): + dk_pred_formula_list = dk_pred_formula_list_dict[dk] + dk_no_answer_labels = [] + for item in dk_pred_formula_list: + if item == '' or item is None: + dk_no_answer_labels.append(True) + else: + dk_no_answer_labels.append(False) + dk_no_answer_labels = np.array(dk_no_answer_labels) + tk_no_answer_labels = tk_no_answer_labels & dk_no_answer_labels + metric_results['num_t%d_no_answer' % + (dk + 1)] = tk_no_answer_labels.sum().item() + + for metric in metrics: + if metric == 'exact_match': + tk_exact_match_labels = np.array([False] * num_all) + for dk in range(k): + dk_pred_formula_list = dk_pred_formula_list_dict[dk] + dk_exact_match_labels = judge_string_exact_match( + dk_pred_formula_list, golds_formula_list) + tk_exact_match_labels = tk_exact_match_labels | \ + dk_exact_match_labels + metric_results['num_t%d_exact_match' % + (dk + 1)] = tk_exact_match_labels.sum().item() + elif metric == 'element_match': + tk_ele_match_labels = np.array([False] * num_all) + tk_formula_invalid_labels = np.array([True] * num_all) + for dk in range(k): + dk_pred_formula_list = dk_pred_formula_list_dict[dk] + dk_ele_match_labels, dk_formula_invalid_labels = \ + count_element_match( + dk_pred_formula_list, golds_formula_list) + tk_ele_match_labels = tk_ele_match_labels | dk_ele_match_labels + tk_formula_invalid_labels = tk_formula_invalid_labels & \ + dk_formula_invalid_labels + metric_results['num_t%d_ele_match' % + (dk + 1)] = tk_ele_match_labels.sum().item() + metric_results['num_t%d_formula_invalid' % + (dk + + 1)] = tk_formula_invalid_labels.sum().item() + elif metric == 'split_match': + tk_exact_match_labels = np.array([False] * num_all) + for dk in range(k): + dk_pred_formula_list = dk_pred_formula_list_dict[dk] + dk_exact_match_labels = judge_string_split_match( + dk_pred_formula_list, golds_formula_list) + tk_exact_match_labels = tk_exact_match_labels | \ + dk_exact_match_labels + metric_results['num_t%d_split_match' % + (dk + 1)] = tk_exact_match_labels.sum().item() + else: + raise ValueError(metric) + + return metric_results + + +def calculate_text_metrics(pred_text_list, + gold_text_list, + text_model='allenai/scibert_scivocab_uncased', + text_trunc_length=512): + assert len(pred_text_list) == len(gold_text_list) + pred_text_list = [(item[0].strip() if item is not None else '') + for item in pred_text_list] + gold_text_list = [item[0].strip() for item in gold_text_list] + + num_no_answer = 0 + for pred_formula in 
pred_text_list: + if pred_formula == '': + num_no_answer += 1 + + text_tokenizer = BertTokenizerFast.from_pretrained(text_model) + + meteor_scores = [] + + references = [] + hypotheses = [] + + for i, (gt, out) in enumerate(zip(gold_text_list, pred_text_list)): + if out == '': + continue + + gt_tokens = text_tokenizer.tokenize(gt, + truncation=True, + max_length=text_trunc_length, + padding='max_length') + gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens)) + gt_tokens = list(filter(('[CLS]').__ne__, gt_tokens)) + gt_tokens = list(filter(('[SEP]').__ne__, gt_tokens)) + + out_tokens = text_tokenizer.tokenize(out, + truncation=True, + max_length=text_trunc_length, + padding='max_length') + out_tokens = list(filter(('[PAD]').__ne__, out_tokens)) + out_tokens = list(filter(('[CLS]').__ne__, out_tokens)) + out_tokens = list(filter(('[SEP]').__ne__, out_tokens)) + + references.append([gt_tokens]) + hypotheses.append(out_tokens) + + mscore = meteor_score([gt_tokens], out_tokens) + meteor_scores.append(mscore) + + bleu2 = corpus_bleu(references, hypotheses, weights=(.5, .5)) + bleu4 = corpus_bleu(references, hypotheses, weights=(.25, .25, .25, .25)) + + _meteor_score = np.mean(meteor_scores) + + scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL']) + + rouge_scores = [] + + references = [] + hypotheses = [] + + for i, (gt, out) in enumerate(zip(gold_text_list, pred_text_list)): + if out == '': + continue + + rs = scorer.score(out, gt) + rouge_scores.append(rs) + + rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores]) + rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores]) + rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores]) + + result = { + 'num_all': len(pred_text_list), + 'num_no_answer': num_no_answer, + 'bleu2': bleu2, + 'bleu4': bleu4, + 'rouge_1': rouge_1, + 'rouge_2': rouge_2, + 'rouge_l': rouge_l, + 'meteor_score': _meteor_score, + } + + return result + + +def calculate_number_metrics(pred_text_list, gold_text_list): + assert len(pred_text_list) == len(gold_text_list) + num_all = len(pred_text_list) + metrics = {} + metrics['num_all'] = num_all + num_no_answer = 0 + num_invalid = 0 + new_pred_text_list, new_gold_text_list = [], [] + for (pred_item, gold_item) in zip(pred_text_list, gold_text_list): + if pred_item is None: + num_no_answer += 1 + continue + assert len(pred_item) == 1 + assert len(gold_item) == 1 + pred_item = pred_item[0] + gold_item = gold_item[0] + if pred_item == '': + num_no_answer += 1 + continue + try: + pred_item = float(pred_item) + except (SyntaxError, ValueError): + # print("\"%s\"" % pred_item) + num_invalid += 1 + continue + gold_item = float(gold_item) + new_pred_text_list.append(pred_item) + new_gold_text_list.append(gold_item) + + new_pred_text_list = np.array(new_pred_text_list) + new_gold_text_list = np.array(new_gold_text_list) + score = np.sqrt(((new_pred_text_list - new_gold_text_list)**2).mean()) + + metrics['num_no_answer'] = num_no_answer + metrics['num_invalid'] = num_invalid + metrics['RMSE'] = score + + return metrics + + +def calculate_boolean_metrics(pred_text_list, gold_text_list): + assert len(pred_text_list) == len(gold_text_list) + num_all = len(pred_text_list) + metrics = {} + metrics['num_all'] = num_all + num_no_answer = 0 + num_invalid = 0 + num_correct = 0 + new_pred_text_list, new_gold_text_list = [], [] + for (pred_item, gold_item) in zip(pred_text_list, gold_text_list): + if pred_item is None or pred_item == '': + num_no_answer += 1 + continue + assert len(pred_item) == 1 + assert 
len(gold_item) == 1 + pred_item = pred_item[0].strip().lower() + gold_item = gold_item[0].strip().lower() + if pred_item == '': + num_no_answer += 1 + continue + if pred_item not in ('yes', 'no'): + num_invalid += 1 + continue + pred_item = 1 if pred_item == 'yes' else 0 + gold_item = 1 if gold_item == 'yes' else 0 + new_pred_text_list.append(pred_item) + new_gold_text_list.append(gold_item) + if gold_item == pred_item: + num_correct += 1 + + metrics['num_no_answer'] = num_no_answer + metrics['num_invalid'] = num_invalid + metrics['num_correct'] = num_correct + + # return metrics + + new_gold_text_list = np.array(new_gold_text_list) + new_pred_text_list = np.array(new_pred_text_list) + + macro_roc_auc_score = roc_auc_score(new_gold_text_list, new_pred_text_list) + f1 = f1_score(new_gold_text_list, new_pred_text_list) + metrics['roc_auc_score'] = macro_roc_auc_score + metrics['precision'] = precision_score(new_gold_text_list, + new_pred_text_list) + metrics['recall'] = recall_score(new_gold_text_list, new_pred_text_list) + metrics['f1_score'] = f1 + + no_mask = (new_gold_text_list == 0) + new_gold_text_list[no_mask] = -1 + no_mask = (new_pred_text_list == 0) + new_pred_text_list[no_mask] = -1 + metrics['mcc'] = matthews_corrcoef(new_gold_text_list, new_pred_text_list) + + return metrics diff --git a/opencompass/datasets/SciReasoner/LLM4Chem/utils/smiles_canonicalization.py b/opencompass/datasets/SciReasoner/LLM4Chem/utils/smiles_canonicalization.py new file mode 100644 index 000000000..6401bd644 --- /dev/null +++ b/opencompass/datasets/SciReasoner/LLM4Chem/utils/smiles_canonicalization.py @@ -0,0 +1,189 @@ +try: + from rdkit import Chem, RDLogger + from rdkit.Chem.AllChem import AssignStereochemistry +except Exception: + Chem, RDLogger, AssignStereochemistry = None, None, None + +# RDLogger.DisableLog('rdApp.*') + + +def canonicalize(smiles, isomeric=False, canonical=True, kekulize=False): + # When canonicalizing a SMILES string, we typically want to + # run Chem.RemoveHs(mol), but this will try to kekulize the mol + # which is not required for canonical SMILES. Instead, we make a + # copy of the mol retaining only the information we desire + # (not explicit Hs) + # Then, we sanitize the mol without kekulization. + # copy_atom and copy_edit_mol + # Are used to create this clean copy of the mol. 
+ def copy_atom(atom): + new_atom = Chem.Atom(atom.GetSymbol()) + new_atom.SetFormalCharge(atom.GetFormalCharge()) + if atom.GetIsAromatic() and atom.GetNoImplicit(): + new_atom.SetNumExplicitHs(atom.GetNumExplicitHs()) + # elif atom.GetSymbol() == 'N': + # print(atom.GetSymbol()) + # print(atom.GetImplicitValence()) + # new_atom.SetNumExplicitHs(-atom.GetImplicitValence()) + # elif atom.GetSymbol() == 'S': + # print(atom.GetSymbol()) + # print(atom.GetImplicitValence()) + return new_atom + + def copy_edit_mol(mol): + from rdchiral.chiral import copy_chirality + + new_mol = Chem.RWMol(Chem.MolFromSmiles('')) + for atom in mol.GetAtoms(): + new_atom = copy_atom(atom) + new_mol.AddAtom(new_atom) + for bond in mol.GetBonds(): + a1 = bond.GetBeginAtom().GetIdx() + a2 = bond.GetEndAtom().GetIdx() + bt = bond.GetBondType() + new_mol.AddBond(a1, a2, bt) + new_bond = new_mol.GetBondBetweenAtoms(a1, a2) + new_bond.SetBondDir(bond.GetBondDir()) + new_bond.SetStereo(bond.GetStereo()) + for new_atom in new_mol.GetAtoms(): + atom = mol.GetAtomWithIdx(new_atom.GetIdx()) + copy_chirality(atom, new_atom) + return new_mol + + smiles = smiles.replace(' ', '') + tmp = Chem.MolFromSmiles(smiles, sanitize=False) + tmp.UpdatePropertyCache() + new_mol = copy_edit_mol(tmp) + # Chem.SanitizeMol(new_mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL) + if not kekulize: + Chem.SanitizeMol(new_mol, + sanitizeOps=Chem.SanitizeFlags.SANITIZE_SETAROMATICITY + | Chem.SanitizeFlags.SANITIZE_PROPERTIES + | Chem.SanitizeFlags.SANITIZE_ADJUSTHS, + catchErrors=True) + else: + Chem.SanitizeMol(new_mol, + sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE + | Chem.SanitizeFlags.SANITIZE_PROPERTIES + | Chem.SanitizeFlags.SANITIZE_ADJUSTHS, + catchErrors=True) + + AssignStereochemistry(new_mol, + cleanIt=False, + force=True, + flagPossibleStereoCenters=True) + + new_smiles = Chem.MolToSmiles(new_mol, + isomericSmiles=isomeric, + canonical=canonical) + return new_smiles + + +def canonicalize_molecule_smiles(smiles, + return_none_for_error=True, + skip_mol=False, + sort_things=True, + isomeric=True, + kekulization=True, + allow_empty_part=False): + things = smiles.split('.') + if skip_mol: + new_things = things + else: + new_things = [] + for thing in things: + try: + if thing == '' and not allow_empty_part: + raise ValueError('SMILES contains empty part.') + + mol = Chem.MolFromSmiles(thing) + # print(f"smiles = {thing} mol = {mol}") + if mol is None: + return thing + assert mol is not None + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + thing_smiles = Chem.MolToSmiles(mol, + kekuleSmiles=False, + isomericSmiles=isomeric) + thing_smiles = Chem.MolFromSmiles(thing_smiles) + thing_smiles = Chem.MolToSmiles(thing_smiles, + kekuleSmiles=False, + isomericSmiles=isomeric) + thing_smiles = Chem.MolFromSmiles(thing_smiles) + thing_smiles = Chem.MolToSmiles(thing_smiles, + kekuleSmiles=False, + isomericSmiles=isomeric) + assert thing_smiles is not None + can_in = thing_smiles + can_out = canonicalize(thing_smiles, isomeric=isomeric) + assert can_out is not None, can_in + thing_smiles = can_out + if kekulization: + thing_smiles = keku_mid = Chem.MolFromSmiles(thing_smiles) + assert keku_mid is not None, \ + 'Before can: %s\nAfter can: %s' % ( + can_in, can_out) + thing_smiles = Chem.MolToSmiles(thing_smiles, + kekuleSmiles=True, + isomericSmiles=isomeric) + except KeyboardInterrupt: + raise + except Exception: + if return_none_for_error: + return None + else: + raise + new_things.append(thing_smiles) + if sort_things: + new_things = 
sorted(new_things)
+    new_things = '.'.join(new_things)
+    return new_things
+
+
+def canonicalize_reaction_smiles(smiles,
+                                 return_none_for_error=True,
+                                 return_segs=False,
+                                 skip_mol=False,
+                                 sort_things=True,
+                                 isomeric=True,
+                                 kekulization=True):
+    segs = smiles.split('>')
+    assert len(segs) == 3
+    new_segs = []
+    for seg in segs:
+        if seg != '':
+            new_things = canonicalize_molecule_smiles(
+                seg,
+                return_none_for_error=return_none_for_error,
+                skip_mol=skip_mol,
+                sort_things=sort_things,
+                isomeric=isomeric,
+                kekulization=kekulization)
+            if return_none_for_error and new_things is None:
+                return None
+            new_segs.append(new_things)
+        else:
+            new_segs.append('')
+
+    if return_segs:
+        return tuple(new_segs)
+
+    smiles = '>'.join(new_segs)
+    return smiles
+
+
+def get_molecule_id(smiles, remove_duplicate=True):
+    if remove_duplicate:
+        assert ';' not in smiles
+        all_inchi = set()
+        for part in smiles.split('.'):
+            inchi = get_molecule_id(part, remove_duplicate=False)
+            all_inchi.add(inchi)
+        all_inchi = tuple(sorted(all_inchi))
+        return all_inchi
+    else:
+        mol = Chem.MolFromSmiles(smiles)
+        if mol is None:
+            return ''
+        return Chem.MolToInchi(mol)
diff --git a/opencompass/datasets/SciReasoner/LLM4Mat.py b/opencompass/datasets/SciReasoner/LLM4Mat.py
new file mode 100644
index 000000000..65e2397ae
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/LLM4Mat.py
@@ -0,0 +1,217 @@
+# flake8: noqa
+
+import json
+import os
+import re
+from typing import Union
+
+import numpy as np
+from datasets import Dataset, DatasetDict
+from huggingface_hub import hf_hub_download
+from sklearn.metrics import (mean_absolute_error, mean_squared_error,
+                             roc_auc_score)
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+from opencompass.utils import get_data_path
+
+
+@LOAD_DATASET.register_module()
+class LLM4MatDataset(BaseDataset):
+
+    @staticmethod
+    def load(path,
+             property,
+             train_path,
+             test_path,
+             mini_set=False) -> DatasetDict:
+
+        def load_single_dataset(path, property, num=None):
+
+            # if (hf_hub is True):
+            #     repo_id = path.split('/')[0] + '/' + path.split('/')[1]
+            #     path = path.split(repo_id + '/')[1]
+            #
+            #     path = hf_hub_download(repo_id, path, repo_type='dataset')
+
+            with open(path, 'r', encoding='utf-8') as f:
+                raw_data = json.load(f)
+            if isinstance(raw_data, dict):
+                raw_data = [raw_data]
+
+            processed = []
+            for i, item in enumerate(raw_data):
+                if '{' + f'{property} :' not in item['output']:
+                    continue
+                new_item = {
+                    'input': item['input'],
+                    'output': item['output'],
+                }
+                processed.append(new_item)
+            if num:
+                dataset = Dataset.from_list(processed[:num])
+            else:
+                dataset = Dataset.from_list(processed)
+            return dataset
+
+        path = get_data_path(path)
+        train_path = os.path.join(path, train_path)
+        test_path = os.path.join(path, test_path)
+
+        if mini_set:
+            test_num = 150
+        else:
+            test_num = None
+        dataset = DatasetDict({
+            'train':
+            load_single_dataset(train_path, property, num=5),
+            'test':
+            load_single_dataset(test_path, property, num=test_num)
+        })
+        return dataset
+
+
+non_numeric_props_options = {
+    'Direct_or_indirect': ['Indirect', 'Direct'],
+    'Direct_or_indirect_HSE': ['Indirect', 'Direct'],
+    'SOC': [True, False],
+    'is_gap_direct': [True, False],
+    'is_stable': [True, False],
+}
+
+
+def remove_think_tags(text: str) -> str:
+    if '<think>' not in text:
+        return text
+    if '</think>' not in text:
+        return ''
+    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+
+
+def 
extract_strict_value(text: str, property: str) -> str: + text_clean = re.sub(r'^```(?:json)?\s*|\s*```$', + '', + text.strip(), + flags=re.IGNORECASE | re.MULTILINE) + try: + data = json.loads(text_clean) + if property in data: + raw_value = data[property] + if isinstance(raw_value, (int, float)): + return float(raw_value) + if property in non_numeric_props_options: + options = non_numeric_props_options[property] + for opt in options: + if isinstance(opt, bool): + if str(raw_value).lower() == str(opt).lower(): + return str(opt) + elif str(raw_value).lower() == str(opt).lower(): + return opt + return '' + return str(raw_value) + except Exception: + pass + + pattern = rf'\{{[^{{}}]*"?{re.escape(property)}"?\s*:\s*(.*?)\s*\}}' + match = re.search(pattern, text_clean, flags=re.DOTALL | re.IGNORECASE) + if not match: + return '' + raw_value = match.group(1).strip().strip('"') + if property in non_numeric_props_options: + options = non_numeric_props_options[property] + for opt in options: + if isinstance(opt, bool): + if raw_value.lower() == str(opt).lower(): + return str(opt) + elif raw_value.lower() == opt.lower(): + return opt + return '' + try: + return float(raw_value) + except ValueError: + return '' + + +@TEXT_POSTPROCESSORS.register_module() +def LLM4Mat_postprocessor(text: Union[str, None], property): + if text is None or not isinstance(text, str): + return '' + text = text.strip() + text = remove_think_tags(text) + if text == '': + return '' + result = extract_strict_value(text, property) + return result + + +class LLM4Mat_Evaluator(BaseEvaluator): + + def score(self, predictions, references): + is_regression = isinstance( + references[0], + (int, float)) and not isinstance(references[0], bool) + + if is_regression: + y_true = [] + y_pred = [] + total = len(references) + for t, p in zip(references, predictions): + try: + t_val = float(t) + p_val = float(p) + if not (np.isfinite(t_val) and np.isfinite(p_val)): + continue + y_true.append(t_val) + y_pred.append(p_val) + except Exception: + continue + if len(y_true) == 0: + return { + 'total': total, + 'filtered': len(y_true), + 'MAE': None, + 'RMSE': None, + 'MAD': None, + 'MAD/MAE': None + } + mae = mean_absolute_error(y_true, y_pred) + rmse = mean_squared_error(y_true, y_pred, squared=False) + mean_value = np.mean(y_true) + baseline_pred = [mean_value] * len(y_true) + mad = mean_absolute_error(y_true, baseline_pred) + mad_mae_ratio = mad / mae if mae != 0 else None + return { + 'total': total, + 'filtered': len(y_true), + 'MAE': mae, + 'RMSE': rmse, + 'MAD': mad, + 'MAD/MAE': mad_mae_ratio + } + else: + y_true = [] + y_pred = [] + auc = None + try: + for t, p in zip(references, predictions): + if t in ['Null']: + continue + if t in ['Direct', 'True', True]: + y_true.append(1) + elif t in ['Indirect', 'False', False]: + y_true.append(0) + else: + continue + + if p in ['Direct', 'True', True]: + y_pred.append(1) + elif p in ['Indirect', 'False', False]: + y_pred.append(0) + else: + y_true.pop() + continue + auc = roc_auc_score(y_true, y_pred) + except Exception: + pass + return {'AUC': auc} diff --git a/opencompass/datasets/SciReasoner/Mol_Instructions/__init__.py b/opencompass/datasets/SciReasoner/Mol_Instructions/__init__.py new file mode 100644 index 000000000..a96ce16ce --- /dev/null +++ b/opencompass/datasets/SciReasoner/Mol_Instructions/__init__.py @@ -0,0 +1,15 @@ +from .biotext import Mol_Instructions_Dataset_BioText # noqa: F401, F403 +from .biotext import Mol_Instructions_Evaluator_BioText # noqa: F401, F403 +from .biotext 
import Mol_Instructions_postprocess_BioText # noqa: F401, F403 +from .molecule import Mol_Instructions_Dataset # noqa: F401, F403 +from .molecule import Mol_Instructions_Evaluator_Mol # noqa: F401, F403 +from .molecule import Mol_Instructions_postprocess_Mol # noqa: F401, F403 +from .normalized_SW_score import normalized_smith_waterman # noqa: F401, F403 +from .protein import \ + Mol_Instructions_Dataset_Protein_Design # noqa: F401, F403 +from .protein import Mol_Instructions_Evaluator_Protein # noqa: F401, F403 +from .protein import \ + Mol_Instructions_Evaluator_Protein_Design # noqa: F401, F403 +from .protein import Mol_Instructions_postprocess_Protein # noqa: F401, F403 +from .protein import \ + Mol_Instructions_postprocess_Protein_Design # noqa: F401, F403 diff --git a/opencompass/datasets/SciReasoner/Mol_Instructions/biotext.py b/opencompass/datasets/SciReasoner/Mol_Instructions/biotext.py new file mode 100644 index 000000000..a3d8ce3e7 --- /dev/null +++ b/opencompass/datasets/SciReasoner/Mol_Instructions/biotext.py @@ -0,0 +1,331 @@ +# flake8: noqa +# molecule task +# https://github.com/zjunlp/Mol-Instructions/tree/main/evaluation/molecule + +import json +import os +import re +from typing import List + +from datasets import Dataset, DatasetDict +from huggingface_hub import hf_hub_download +from sklearn.metrics import precision_recall_fscore_support + +from opencompass.datasets.base import BaseDataset +from opencompass.openicl import BaseEvaluator +from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS +from opencompass.utils import get_data_path + + +def CER_calculate_f1_score(true_entities, predicted_entities): + true_entities = set(true_entities.split(', ')) + predicted_entities = set(predicted_entities.split(', ')) + true_positive = len(true_entities & predicted_entities) + precision = true_positive / len(predicted_entities) if len( + predicted_entities) > 0 else 0 + recall = true_positive / len(true_entities) if len( + true_entities) > 0 else 0 + + f1_score = 2 * (precision * recall) / (precision + recall) if ( + precision + recall) > 0 else 0 + # print(true_entities,predicted_entities,f1_score) + return f1_score + + +def calculate_f1_score(true_entities, predicted_entities): + # import pdb;pdb.set_trace() + pattern = r'\(.*?\)' + true_entities = re.findall(pattern, true_entities) + predicted_entities_tmp = re.findall(pattern, predicted_entities) + if not predicted_entities_tmp: + # add () to predicted_entities if it is empty + predicted_entities = f'({predicted_entities})' + predicted_entities_tmp = re.findall(pattern, predicted_entities) + + predicted_entities = [entity.strip() for entity in predicted_entities_tmp] + + true_entities = set(true_entities) + predicted_entities = set(predicted_entities) + true_positive = len(true_entities & predicted_entities) + precision = true_positive / len(predicted_entities) if len( + predicted_entities) > 0 else 0 + recall = true_positive / len(true_entities) if len( + true_entities) > 0 else 0 + + f1_score = 2 * (precision * recall) / (precision + recall) if ( + precision + recall) > 0 else 0 + return f1_score + + +def calculate_accuracy_(predictions, references): + correct_count = 0 + total_count = len(references) + for i, (pred, ref) in enumerate(zip(predictions, references)): + pred = pred[0].lower() + ref = ref[0].lower() + f1_score = calculate_f1_score(ref, pred) + correct_count += f1_score + + return correct_count / total_count + + +def CER_calculate_accuracy_(predictions, references): + correct_count = 0 + total_count = 
len(references)
+    for i, (pred, ref) in enumerate(zip(predictions, references)):
+        pred = pred[0].lower()
+        ref = ref[0].lower()
+        f1_score = CER_calculate_f1_score(ref, pred)
+        correct_count += f1_score
+
+    return correct_count / total_count
+
+
+def true_or_false_calculate_accuracy_(predictions, references):
+    correct_count = 0
+    total_count = len(references)
+    other_answers = 0
+    for i, (pred, ref) in enumerate(zip(predictions, references)):
+        pred = pred[0].lower()
+        ref = ref[0].lower()
+        correct_first_word = ref.split(',')[0].strip().lower()
+        pred = pred.strip().lower()
+        if 'yes' in pred:
+            my_first_word = 'yes'
+        elif 'no' in pred:
+            my_first_word = 'no'
+        elif 'maybe' in pred or 'may be' in pred or 'might' in pred:
+            my_first_word = 'maybe'
+        else:
+            other_answers += 1
+            my_first_word = 'other'
+            print(f'Other answer: {pred}, reference: {ref}')
+
+        if correct_first_word == my_first_word:
+            correct_count += 1
+    accuracy = (correct_count / total_count) * 100
+    return accuracy, other_answers
+
+
+def calculate_macro_f1_(predictions, references):
+    correct_answers = [
+        ref[0].split(',')[0].strip().lower() for ref in references
+    ]
+    my_answers = [
+        pred[0].split(',')[0].strip().lower() for pred in predictions
+    ]
+    # Compute precision, recall, and F1-score for each class
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        correct_answers,
+        my_answers,
+        labels=['yes', 'no', 'maybe'],
+        average=None)
+    # Calculate macro F1 by averaging F1-scores over all classes
+    macro_f1 = sum(f1) / len(f1)
+
+    return macro_f1
+
+
+def multi_choice_question_calculate_accuracy(question_data):
+    correct_count = 0
+    total_count = len(question_data)
+    for i, question in enumerate(question_data):
+        correct_answer = question['output'].split('(')[1].split(')')[0]
+        my_answer = question['my_output'][0]
+        if '(A' in question['my_output'] or 'A)' in question[
+                'my_output'] or ' A ' in question['my_output']:
+            my_answer = 'A'
+        elif '(B' in question['my_output'] or 'B)' in question[
+                'my_output'] or ' B ' in question['my_output']:
+            my_answer = 'B'
+        elif '(C' in question['my_output'] or 'C)' in question[
+                'my_output'] or ' C ' in question['my_output']:
+            my_answer = 'C'
+        elif '(D' in question['my_output'] or 'D)' in question[
+                'my_output'] or ' D ' in question['my_output']:
+            my_answer = 'D'
+        if correct_answer == my_answer:
+            correct_count += 1
+    accuracy = (correct_count / total_count) * 100
+
+    return accuracy
+
+
+def multi_choice_question_calculate_accuracy_(predictions, references):
+    correct_count = 0
+    total_count = len(references)
+    for i, (pred, ref) in enumerate(zip(predictions, references)):
+        correct_answer = ref[0].split('(')[1].split(')')[0]
+        my_answer = pred[0]
+        if '(A' in pred[0] or 'A)' in pred[0] or ' A ' in pred[0]:
+            my_answer = 'A'
+        elif '(B' in pred[0] or 'B)' in pred[0] or ' B ' in pred[0]:
+            my_answer = 'B'
+        elif '(C' in pred[0] or 'C)' in pred[0] or ' C ' in pred[0]:
+            my_answer = 'C'
+        elif '(D' in pred[0] or 'D)' in pred[0] or ' D ' in pred[0]:
+            my_answer = 'D'
+        if correct_answer == my_answer:
+            correct_count += 1
+    accuracy = (correct_count / total_count) * 100
+
+    return accuracy
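+
+
+# Illustrative example (hypothetical prediction): the multi-choice scorer
+# looks for patterns such as '(B', 'B)' or ' B ' in the raw prediction, so
+# a verbose answer still resolves to a single option letter, e.g.
+#
+#   multi_choice_question_calculate_accuracy_(
+#       [['The correct option is (B).']], [['(B) alanine']])   # -> 100.0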
+
+
+@LOAD_DATASET.register_module()
+class Mol_Instructions_Dataset_BioText(BaseDataset):
+
+    @staticmethod
+    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
+
+        # if (hf_hub is True):
+        #     # load from huggingface hub
+        #     train_data = []
+        #     repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1]
+        #     train_path = train_path.split(repo_id + '/')[1]
+        #     test_path = test_path.split(repo_id + '/')[1]
+        #
+        #     train_path = hf_hub_download(repo_id,
+        #                                  train_path,
+        #                                  repo_type='dataset')
+        #     test_path = hf_hub_download(repo_id,
+        #                                 test_path,
+        #                                 repo_type='dataset')
+
+        path = get_data_path(path)
+        train_path = os.path.join(path, f'{task}/dev/data.json')
+        test_path = os.path.join(path, f'{task}/test/data.json')
+
+        with open(train_path, 'r', encoding='utf-8') as f:
+            train_data = json.load(f)
+        with open(test_path, 'r', encoding='utf-8') as f:
+            test_data = json.load(f)
+
+        # Keep only 5 training samples for the few-shot prompt
+        train_data = train_data[:5]
+
+        if (max_cut != -1):
+            test_data = test_data[:max_cut]
+        if mini_set:
+            import random
+            random.seed(1024)
+            test_data = random.sample(test_data, 50)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module('Mol_Instructions_postprocess_BioText')
+def Mol_Instructions_postprocess_BioText(text, task, *args, **kwargs):
+    """
+    Strip end tokens, reasoning wrappers and boilerplate prefixes
+    from biotext task predictions.
+    """
+    text = text.strip()
+    if task in (
+            'chemical_disease_interaction_extraction',
+            'chemical_protein_interaction_extraction',
+            'chemical_entity_recognition',
+            'true_or_false_question',
+            'multi_choice_question',
+            'open_question',
+    ):
+        text = text.strip()
+        text = re.sub(r'<\|endoftext\|>', '', text)
+        text = re.sub(r'<\|im_end\|>', '', text)
+
+        # remove "Response: " or "Answer: " at the beginning for qwen3
+        text = re.sub(r'^(Response:|Answer:)\s*',
+                      '',
+                      text,
+                      flags=re.IGNORECASE)
+        # remove <think>...</think> reasoning blocks
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        # remove the analysis channel before the final answer
+        # for gpt-oss-120b
+        text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+
+        # remove "I would say that" or
+        # "I would like to say that" at the beginning for qwen3
+        text = re.sub(r'^(I would say that|I would like to say that)\s*',
+                      '',
+                      text,
+                      flags=re.IGNORECASE)
+        text = text.strip()
+    else:
+        pass
+    return text
+
+
+class Mol_Instructions_Evaluator_BioText(BaseEvaluator):
+
+    def __init__(self, task='protein_design', *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.task = task
+
+    def score(self, predictions: List[str], references: List[str]):
+        if len(predictions) != len(references):
+            return {
+                'error': 'predictions and references have different '
+                'length'
+            }
+        if not isinstance(predictions[0], list):
+            predictions = [[pred] for pred in predictions]
+        if not isinstance(references[0], list):
+            references = [[ref] for ref in references]
+
+        if self.task in (
+                'chemical_disease_interaction_extraction',
+                'chemical_protein_interaction_extraction',
+        ):
+            results = {
+                'f1': calculate_accuracy_(predictions, references),
+            }
+        elif self.task in ('chemical_entity_recognition', ):
+            results = {
+                'f1': CER_calculate_accuracy_(predictions, references),
+            }
+        elif self.task == 'true_or_false_question':
+            acc, other_answers = true_or_false_calculate_accuracy_(
+                predictions, references)
+            results = {
+                'accuracy': acc,
+                'other_answers': other_answers,
+            }
+        elif self.task == 'multi_choice_question':
+            results = {
+                'accuracy':
+                multi_choice_question_calculate_accuracy_(
predictions, references), + } + elif self.task == 'open_question': + from bert_score import score + correct_answers = [ref[0] for ref in references] + my_answers = [pred[0] for pred in predictions] + P, R, F1 = score(my_answers, + correct_answers, + lang='en', + verbose=False, + num_layers=14, + model_type='FacebookAI/roberta-large') + + results = { + # 'bleu': total_bleu/len(my_answers), + # 'rouge': total_rouge/len(my_answers), + 'bert_score': sum(F1).item() / len(F1), + } + else: + raise ValueError(f'Unknown task: {self.task}') + + return results diff --git a/opencompass/datasets/SciReasoner/Mol_Instructions/molecule.py b/opencompass/datasets/SciReasoner/Mol_Instructions/molecule.py new file mode 100644 index 000000000..743152c4a --- /dev/null +++ b/opencompass/datasets/SciReasoner/Mol_Instructions/molecule.py @@ -0,0 +1,458 @@ +# flake8: noqa +# molecule task +# https://github.com/zjunlp/Mol-Instructions/tree/main/evaluation/molecule + +import json +import re + +import numpy as np +from datasets import Dataset, DatasetDict +from huggingface_hub import hf_hub_download +from Levenshtein import distance as lev +from nltk.translate.bleu_score import corpus_bleu +from nltk.translate.meteor_score import meteor_score + +try: + from rdkit import Chem, DataStructs, RDLogger + from rdkit.Chem import AllChem, MACCSkeys +except Exception: + Chem, DataStructs, RDLogger, AllChem, MACCSkeys = None, None, None, None, None + +try: + import selfies as sf +except Exception: + sf = None + +import os + +from rouge_score import rouge_scorer +from sklearn.metrics import mean_absolute_error +from transformers import BertTokenizerFast + +from opencompass.datasets.base import BaseDataset +from opencompass.openicl import BaseEvaluator +from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS +from opencompass.utils import get_data_path + +# RDLogger.DisableLog('rdApp.*') + + +@LOAD_DATASET.register_module() +class Mol_Instructions_Dataset(BaseDataset): + + @staticmethod + def load(path, task, max_cut=-1, mini_set=False, hf_hub=False): + + # if (hf_hub is True): + # # load from huggingface hub + # train_data = [] + # repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1] + # train_path = train_path.split(repo_id + '/')[1] + # test_path = test_path.split(repo_id + '/')[1] + # + # train_path = hf_hub_download(repo_id, + # train_path, + # repo_type='dataset') + # test_path = hf_hub_download(repo_id, + # test_path, + # repo_type='dataset') + + path = get_data_path(path) + train_path = os.path.join(path, f'{task}/dev/data.json') + test_path = os.path.join(path, f'{task}/test/data.json') + + with open(train_path, 'r', encoding='utf-8') as f: + train_data = json.load(f) + with open(test_path, 'r', encoding='utf-8') as f: + test_data = json.load(f) + + train_data = train_data[:5] + # Limit the dataset to 5 samples for testing purposes + + if (max_cut != -1): + test_data = test_data[:max_cut] + if mini_set: + import random + random.seed(1024) + test_data = random.sample(test_data, 150) + random.seed() + + dataset = DatasetDict({ + 'train': Dataset.from_list(train_data), + 'test': Dataset.from_list(test_data) + }) + return dataset + + +def convert_to_canonical_smiles(smiles): + molecule = Chem.MolFromSmiles(smiles) + if molecule is not None: + canonical_smiles = Chem.MolToSmiles(molecule, + isomericSmiles=False, + canonical=True) + return canonical_smiles + else: + return None + + +@TEXT_POSTPROCESSORS.register_module() +def Mol_Instructions_postprocess_Mol(text, task, *args, **kwargs): + """ + Filter 
end tokens in the sentences ("<|endoftext|>", "<|im_end|>") and
+    reasoning wrappers such as <think>...</think>.
+    """
+    if task == 'property_prediction_str':
+        # For property prediction, we only need the numeric answer
+        text = text.strip()
+        text = re.sub(r'<\|endoftext\|>', '', text)
+        text = re.sub(r'<\|im_end\|>', '', text)
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        # drop the analysis channel before the final answer (gpt-oss-120b)
+        text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+        # remove stray spaces inside numbers, e.g. "1 234" or "0. 5"
+        text = re.sub(r'(?<=\d) +(?=\d)|(?<=\.) +(?=\d)', '', text)
+        num_match = re.search(r'[-+]?\d*\.\d+|\d+', text)
+        text = num_match.group(0) if num_match else 0
+    elif task in [
+            'description_guided_molecule_design',
+            'forward_reaction_prediction',
+            'retrosynthesis',
+            'reagent_prediction',
+    ]:
+        text = text.strip()
+        text = re.sub(r'<\|endoftext\|>', '', text)
+        text = re.sub(r'<\|im_end\|>', '', text)
+        # first filter the reasoning pattern
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+
+        # extract the SMILES string wrapped in <SMILES>...</SMILES> tags
+        pattern = r'<SMILES>(.*?)</SMILES>'
+        match = re.search(pattern, text)
+        if match:
+            smiles = match.group(1).strip()
+            text = convert_to_canonical_smiles(smiles)
+        else:
+            # No SMILES found in the text; downstream metrics will
+            # count this sample as an invalid molecule.
+            text = None
+    elif task in [
+            'molecular_description_generation',
+    ]:
+        text = text.strip()
+        text = re.sub(r'<\|endoftext\|>', '', text)
+        text = re.sub(r'<\|im_end\|>', '', text)
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+
+    return text
+
+
+def compute_MAE_property_prediction_str(predictions, references):
+    y_pred = np.array([float(p[0]) for p in predictions])
+    y_true = np.array([float(r[0]) for r in references])
+    mae = mean_absolute_error(
+        y_true,
+        y_pred) * 1000  # scale to match the presentation of OpenCompass
+    return {'mae': mae}
+
+
+def compute_fingerprint_metricts(
+    predictions,
+    references,
+    morgan_r=2,
+):
+    bad_mols = 0
+    outputs = []
+
+    for pred, refer in zip(predictions, references):
+        try:
+            if pred[0] is None:
+                bad_mols += 1
+                continue
+            pred_ = Chem.MolFromSmiles(pred[0])
+            refer_ = Chem.MolFromSmiles(refer[0])
+            if pred_ is None:
+                bad_mols += 1
+                continue
+            outputs.append((refer_, pred_))
+        except Exception:
+            # Unparseable pair; count it as invalid instead of crashing.
+            bad_mols += 1
+            continue
+
+    validity_score = len(outputs) / (len(outputs) + bad_mols)
+
+    MACCS_sims = []
+    morgan_sims = []
+    RDK_sims = []
+
+    enum_list = outputs
+
+    for i, (gt_m, ot_m) in enumerate(enum_list):
+        MACCS_sims.append(
+            DataStructs.FingerprintSimilarity(
+                MACCSkeys.GenMACCSKeys(gt_m),
+                MACCSkeys.GenMACCSKeys(ot_m),
+                metric=DataStructs.TanimotoSimilarity))
+        RDK_sims.append(
+            DataStructs.FingerprintSimilarity(
+                Chem.RDKFingerprint(gt_m),
+                Chem.RDKFingerprint(ot_m),
+                metric=DataStructs.TanimotoSimilarity))
+        morgan_sims.append(
+            DataStructs.TanimotoSimilarity(
+                AllChem.GetMorganFingerprint(gt_m, morgan_r),
+                AllChem.GetMorganFingerprint(ot_m, morgan_r)))
+
+    maccs_sims_score = np.mean(MACCS_sims)
+    rdk_sims_score = np.mean(RDK_sims)
+    morgan_sims_score = np.mean(morgan_sims)
+
+    return {
+        'validity_score': validity_score,
+        'maccs_sims_score': maccs_sims_score,
+        'rdk_sims_score': rdk_sims_score,
+        'morgan_sims_score': morgan_sims_score
+    }
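+
+
+# Illustrative example (hypothetical molecules): each similarity score is a
+# Tanimoto coefficient in [0, 1] between the gold and predicted
+# fingerprints, so a prediction identical to the reference, e.g.
+#
+#   compute_fingerprint_metricts([['CCO']], [['CCO']])
+#
+# yields validity_score == 1.0 and all three similarity scores == 1.0.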
+
+
+def compute_mol_translation_selfies(predictions, references):
+    outputs = []
+    bad_mols = 0
+    for pred, refer in zip(predictions, references):
+        if pred[0] is None:
+            bad_mols += 1
+            continue
+        pred_canonical_smiles = pred[0]
+        refer_canonical_smiles = refer[0]
+        try:
+            pred_sf = sf.encoder(pred_canonical_smiles)
+            refer_sf = sf.encoder(refer_canonical_smiles)
+        except Exception:
+            bad_mols += 1
+            continue
+
+        outputs.append(
+            (refer_sf, pred_sf, refer_canonical_smiles,
+             pred_canonical_smiles))
+
+    references_self = []
+    hypotheses_self = []
+
+    references_smi = []
+    hypotheses_smi = []
+
+    for i, (gt_self, ot_self, gt_smi, ot_smi) in enumerate(outputs):
+        gt_self_tokens = [c for c in gt_self]
+        out_self_tokens = [c for c in ot_self]
+
+        references_self.append([gt_self_tokens])
+        hypotheses_self.append(out_self_tokens)
+
+        gt_smi_tokens = [c for c in gt_smi]
+        ot_smi_tokens = [c for c in ot_smi]
+
+        references_smi.append([gt_smi_tokens])
+        hypotheses_smi.append(ot_smi_tokens)
+
+    # BLEU score
+    if not references_self or not hypotheses_self:
+        bleu_score_self = 0.0
+    else:
+        bleu_score_self = corpus_bleu(references_self, hypotheses_self)
+
+    references_self = []
+    hypotheses_self = []
+
+    references_smi = []
+    hypotheses_smi = []
+
+    levs_self = []
+    levs_smi = []
+
+    num_exact = 0
+
+    for i, (gt_self, ot_self, gt_smi, ot_smi) in enumerate(outputs):
+
+        hypotheses_self.append(ot_self)
+        references_self.append(gt_self)
+
+        hypotheses_smi.append(ot_smi)
+        references_smi.append(gt_smi)
+
+        try:
+            m_out = Chem.MolFromSmiles(ot_smi)
+            m_gt = Chem.MolFromSmiles(gt_smi)
+
+            # Compare InChI strings so that equivalent SMILES spellings
+            # still count as an exact match.
+            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt):
+                num_exact += 1
+        except Exception:
+            bad_mols += 1
+
+        levs_self.append(lev(ot_self, gt_self))
+        levs_smi.append(lev(ot_smi, gt_smi))
+
+    # Exact matching score
+    exact_match_score = num_exact / len(outputs) if outputs else 0.0
+
+    # Levenshtein score
+    levenshtein_score_smi = np.mean(levs_smi)
+
+    return {
+        'bleu_self_scores': bleu_score_self,
+        'exact_match_score': exact_match_score,
+        'levenshtein_score_smi': levenshtein_score_smi,
+    }
+
+
+def fix_smiles_brackets(smiles):
+    """Patch missing closing parentheses in a SMILES string."""
+    if not isinstance(smiles, str):
+        return smiles
+
+    left_count = smiles.count('(')
+    right_count = smiles.count(')')
+    missing = left_count - right_count
+
+    if missing > 0:
+        return smiles + ')' * missing
+    return smiles
+
+
+class Mol_Instructions_Evaluator_Mol(BaseEvaluator):
+
+    def __init__(self, task, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.task = task
+
+    def score(self, predictions, references):
+        if len(predictions) != len(references):
+            return {
+                'error': 'predictions and references have different '
+                'length'
+            }
+        if not isinstance(predictions[0], list):
+            predictions = [[pred] for pred in predictions]
+        if not isinstance(references[0], list):
+            references = [[ref] for ref in references]
+        task = self.task
+        pred_list = predictions
+        gold_list = references
+
+        if task in ('property_prediction_str', ):
+            results = compute_MAE_property_prediction_str(pred_list, gold_list)
+        elif task in ('description_guided_molecule_design',
+                      'forward_reaction_prediction', 'retrosynthesis',
+                      'reagent_prediction'):
+            fingerprint_metrics = compute_fingerprint_metricts(
+                pred_list, gold_list)
+            mol_translation_selfies = compute_mol_translation_selfies(
+                pred_list, gold_list)
+            # Combine the results from both computations
+            results = {**fingerprint_metrics, **mol_translation_selfies}
+            # reorder to: 'exact', 'bleu', 'levenshtein', 'RDK',
+            # 'MACCS', 'Morgan', 'validity'
+            results = {
+                'exact_match_score': results['exact_match_score'],
+                'bleu_self_scores': results['bleu_self_scores'],
+                'levenshtein_score_smi': results['levenshtein_score_smi'],
+                'rdk_sims_score': results['rdk_sims_score'],
+                'maccs_sims_score': results['maccs_sims_score'],
+                'morgan_sims_score': results['morgan_sims_score'],
+                'validity_score': results['validity_score']
+            }
+        elif task in ('molecular_description_generation', ):
+            results = compute_text_translation_metrics(pred_list, gold_list)
+        else:
+            raise ValueError(task)
+
+        return results
+
+
+def compute_text_translation_metrics(
+        predictions,
+        references,
+        text_model='allenai/scibert_scivocab_uncased',
+        text_trunc_length=512):
+    outputs = []
+
+    for pred, refer in zip(predictions, references):
+        # Trim a trailing unfinished sentence from the prediction.
+        pred_ = pred[0].rsplit('.', 1)[0] + '.' if isinstance(
+            pred[0], str) else pred[0]
+        outputs.append((refer[0], pred_))
+
+    text_tokenizer = BertTokenizerFast.from_pretrained(text_model)
+
+    meteor_scores = []
+
+    references = []
+    hypotheses = []
+
+    for i, (gt, out) in enumerate(outputs):
+        gt_tokens = text_tokenizer.tokenize(gt,
+                                            truncation=True,
+                                            max_length=text_trunc_length,
+                                            padding='max_length')
+        gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens))
+        gt_tokens = list(filter(('[CLS]').__ne__, gt_tokens))
+        gt_tokens = list(filter(('[SEP]').__ne__, gt_tokens))
+
+        out_tokens = text_tokenizer.tokenize(out,
+                                             truncation=True,
+                                             max_length=text_trunc_length,
+                                             padding='max_length')
+        out_tokens = list(filter(('[PAD]').__ne__, out_tokens))
+        out_tokens = list(filter(('[CLS]').__ne__, out_tokens))
+        out_tokens = list(filter(('[SEP]').__ne__, out_tokens))
+
+        references.append([gt_tokens])
+        hypotheses.append(out_tokens)
+
+        mscore = meteor_score([gt_tokens], out_tokens)
+        meteor_scores.append(mscore)
+
+    bleu2 = corpus_bleu(references, hypotheses, weights=(.5, .5))
+    bleu4 = corpus_bleu(references, hypotheses, weights=(.25, .25, .25, .25))
+
+    _meteor_score = np.mean(meteor_scores)
+
+    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
+
+    rouge_scores = []
+
+    references = []
+    hypotheses = []
+
+    for i, (gt, out) in enumerate(outputs):
+        rs = scorer.score(out, gt)
+        rouge_scores.append(rs)
+
+    rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores])
+    rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores])
+    rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores])
+
+    return {
+        'bleu2': bleu2,
+        'bleu4': bleu4,
+        'meteor_score': _meteor_score,
+        'rouge1': rouge_1,
+        'rouge2': rouge_2,
+        'rougeL': rouge_l
+    }
diff --git a/opencompass/datasets/SciReasoner/Mol_Instructions/normalized_SW_score.py b/opencompass/datasets/SciReasoner/Mol_Instructions/normalized_SW_score.py
new file mode 100644
index 000000000..f35fd654a
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/Mol_Instructions/normalized_SW_score.py
@@ -0,0 +1,150 @@
+import math
+
+
+def normalized_smith_waterman(seq1,
+                              seq2,
+                              matrix_name='BLOSUM45',
+                              open_gap=-10,
+                              extend_gap=-0.5):
+    """
+    Compute normalized Smith-Waterman score for protein sequences.
+
+    Args:
+        seq1, seq2 (str): Protein sequences (uppercase letters)
+        matrix_name (str): Name of substitution matrix (default: BLOSUM45)
+        open_gap (float): Gap opening penalty
+        extend_gap (float): Gap extension penalty
+
+    Returns:
+        float: Normalized score between 0.0 and 1.0
+    """
+
+    from Bio.Align import PairwiseAligner, substitution_matrices
+
+    # Initialize aligner
+    aligner = PairwiseAligner()
+    aligner.mode = 'local'  # Smith-Waterman algorithm
+    aligner.open_gap_score = open_gap
+    aligner.extend_gap_score = extend_gap
+
+    # Load substitution matrix
+    try:
+        matrix = substitution_matrices.load(matrix_name)
+    except ValueError:
+        raise ValueError(f'Matrix {matrix_name} not available.'
+                         f' Try: {substitution_matrices.load()}')
+
+    # Set substitution matrix
+    aligner.substitution_matrix = matrix
+
+    # Calculate raw alignment score
+    raw_score = aligner.score(seq1, seq2)
+    if raw_score <= 0:
+        return 0.0
+
+    # Calculate self-alignment scores
+    def calc_self_score(seq):
+        """Calculate maximum possible self-alignment score"""
+        score = 0
+        for aa in seq:
+            try:
+                score += matrix[aa, aa]
+            except KeyError:
+                # Residue not present in the matrix; skip it.
+                continue
+        return score
+
+    self_score1 = calc_self_score(seq1)
+    self_score2 = calc_self_score(seq2)
+
+    # Handle invalid self-scores
+    if self_score1 <= 0 or self_score2 <= 0:
+        return 0.0
+
+    # Compute normalization factor (geometric mean)
+    norm_factor = math.sqrt(self_score1 * self_score2)
+
+    return min(raw_score / norm_factor, 1.0)
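+
+
+# Illustrative example (hypothetical sequences): a sequence aligned against
+# itself attains its own self-alignment score, so the normalized score is
+# 1.0, while sequences with no positive-scoring local alignment return 0.0:
+#
+#   normalized_smith_waterman('MKTAYIAK', 'MKTAYIAK')  # -> 1.0
+#   normalized_smith_waterman('MKT', 'WWWW')           # -> 0.0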
+
+
+# Example usage
+if __name__ == '__main__':
+    # Example sequences (replace with real protein sequences):
+    # target_sequence = "MGGKWSKSSIVGWPAVRERIRQTEPRTEPAA"  # target
+    # generated_sequence = "MGGKWSKSSIVGWPAVRERIRRTEPAA"  # model output
+    #
+    # # target_sequence = 'MSTNPKPQRKTKRNTNRRPQDVKFPGGG'
+    # # generated_sequence = 'MSTNPKPQRKTKRNTNRRPQDVK'
+    #
+    # # Compute the normalized Smith-Waterman score
+    # normalized_sw_score = normalized_smith_waterman(
+    #     target_sequence,
+    #     generated_sequence,
+    # )
+    # print(f"Normalized Smith-Waterman score: {normalized_sw_score:.4f}")
+    import json
+    import os
+    import re
+
+    def Mol_Instructions_postprocess_Protein_Design(text, *args, **kwargs):
+        """
+        Extract the protein string between <protein> and </protein>
+        tags in the sentences.
+        """
+        text = text.strip()
+        pattern = r'<protein>(.*?)</protein>'
+        match = re.search(pattern, text)
+        if match:
+            text = match.group(1)
+            # keep only letters from the amino-acid alphabet
+            valid_letters = set('ACDEFGHIKLMNPQRSTVWY')
+            text = ''.join(filter(lambda x: x in valid_letters, text))
+        else:
+            text = ''
+        return text
+
+    pred_list = []
+    gt_list = []
+    scores = []
+    json_dir = (
+        '/root/code/opencompass-sci/outputs/protein/mol_instructions/'
+        '20250619_185027/predictions/qwen3-1.7B-sft-protein_0.7T_0.9p_50k')
+    for filename in os.listdir(json_dir):
+        if filename.endswith('.json'):
+            file_path = os.path.join(json_dir, filename)
+            with open(file_path, 'r') as f:
+                data = json.load(f)
+            for key, value in data.items():
+                pred = Mol_Instructions_postprocess_Protein_Design(
+                    value['prediction'])
+                gt = Mol_Instructions_postprocess_Protein_Design(
+                    value['gold'])
+                pred_list.append(pred)
+                gt_list.append(gt)
+                if not pred or not gt:
+                    scores.append(0.0)
+                else:
+                    # Calculate the normalized Smith-Waterman score
+                    try:
+                        score = normalized_smith_waterman(pred, gt)
+                        scores.append(score)
+                    except Exception as e:
+                        print(f'Failed to score one sample: {e}')
+                        scores.append(0.0)
+
+    if scores:
+        print(f'Average normalized SW score: '
+              f'{sum(scores) / len(scores):.4f}')
diff --git a/opencompass/datasets/SciReasoner/Mol_Instructions/protein.py b/opencompass/datasets/SciReasoner/Mol_Instructions/protein.py
new file mode 100644
index 000000000..530d9b38b
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/Mol_Instructions/protein.py
@@ -0,0 +1,155 @@
+# flake8: noqa
+# protein tasks
+# https://github.com/zjunlp/Mol-Instructions/tree/main/evaluation/molecule
+
+import json
+import os
+import re
+from typing import List, Optional
+
+from datasets import Dataset, DatasetDict
+from huggingface_hub import hf_hub_download
+from mmengine.config import ConfigDict
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.datasets.SciReasoner.Mol_Instructions.normalized_SW_score import \
+    normalized_smith_waterman
+from opencompass.openicl import BaseEvaluator, RougeEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+from opencompass.utils import get_data_path
+
+
+@LOAD_DATASET.register_module()
+class Mol_Instructions_Dataset_Protein_Design(BaseDataset):
+
+    @staticmethod
+    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
+        # if (hf_hub is True):
+        #     # load from huggingface hub
+        #     train_data = []
+        #     repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1]
+        #     train_path = train_path.split(repo_id + '/')[1]
+        #     test_path = test_path.split(repo_id + '/')[1]
+        #
+        #     train_path = hf_hub_download(repo_id,
+        #                                  train_path,
+        #                                  repo_type='dataset')
+        #     test_path = hf_hub_download(repo_id,
+        #                                 test_path,
+        #                                 repo_type='dataset')
+
+        path = get_data_path(path)
+        train_path = os.path.join(path, f'{task}/dev/data.json')
+        test_path = os.path.join(path, f'{task}/test/data.json')
+
+        with open(train_path, 'r', encoding='utf-8') as f:
+            train_data = json.load(f)
+        with open(test_path, 'r', encoding='utf-8') as f:
+            test_data = json.load(f)
+
+        # Keep only 5 training samples for the few-shot prompt
+        train_data = train_data[:5]
+
+        if (max_cut != -1):
+            test_data = test_data[:max_cut]
+        if mini_set:
+            import random
+            random.seed(1024)
+            test_data = random.sample(test_data, 150)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module('Mol_Instructions_postprocess_Protein')
+def Mol_Instructions_postprocess_Protein(text, *args, **kwargs):
+    """
+    Filter end tokens ("<|endoftext|>", "<|im_end|>") and reasoning
+    wrappers from protein task predictions.
+    """
+    text = text.strip()
+    text = re.sub(r'<\|endoftext\|>', '', text)
+    text = re.sub(r'<\|im_end\|>', '', text)
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+    text = text.strip()
+
+    return text
+
+
+class Mol_Instructions_Evaluator_Protein(RougeEvaluator):
+
+    def __init__(self,
+                 task='catalytic_activity',
+                 pred_postprocessor: Optional[ConfigDict] = None):
+        super().__init__(pred_postprocessor=pred_postprocessor, )
+        self.task = task
+
+
+@TEXT_POSTPROCESSORS.register_module(
+    'Mol_Instructions_postprocess_Protein_Design')
+def Mol_Instructions_postprocess_Protein_Design(text, *args, **kwargs):
+    """
+    Extract the protein string between <protein> and </protein> tags
+    in the sentences.
+    """
+    text = text.strip()
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+    pattern = r'<protein>(.*?)</protein>'
+    match = re.search(pattern, text)
+ if match: + text = match.group(1) + valid_letters = set('ACDEFGHIKLMNPQRSTVWY') + text = ''.join(filter(lambda x: x in valid_letters, text)) + else: + text = '' + return text + + +class Mol_Instructions_Evaluator_Protein_Design(BaseEvaluator): + + def __init__(self, task='protein_design', *args, **kwargs): + super().__init__(*args, **kwargs) + self.task = task + + def score(self, predictions: List[str], references: List[str]): + if len(predictions) != len(references): + return { + 'error': 'predictions and references have different ' + 'length' + } + if not isinstance(predictions[0], list): + predictions = [[pred] for pred in predictions] + if not isinstance(references[0], list): + references = [[ref] for ref in references] + + scores = [] + for pred, refer in zip(predictions, references): + pred = pred[0].strip() + refer = refer[0].strip() + if not pred or not refer: + scores.append(0.0) + else: + # Calculate the normalized Smith-Waterman score + score = normalized_smith_waterman( + pred, refer) * 100 # Convert to percentage + scores.append(score) + + averaged_valid_scores = [score for score in scores if score > 0] + + results = { + 'Max SW score': + max(scores), + 'Min SW score': + min(scores), + 'Average SW score': + sum(scores) / len(scores), + 'valid average SW score': + sum(averaged_valid_scores) / + len(averaged_valid_scores) if averaged_valid_scores else 0.0, + } + return results diff --git a/opencompass/datasets/SciReasoner/PEER.py b/opencompass/datasets/SciReasoner/PEER.py new file mode 100644 index 000000000..3ac219c87 --- /dev/null +++ b/opencompass/datasets/SciReasoner/PEER.py @@ -0,0 +1,471 @@ +# flake8: noqa +# dataset: PEER +# task : solubility prediction + +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Union + +import numpy as np +from datasets import Dataset, DatasetDict +from huggingface_hub import hf_hub_download +from openai import OpenAI +from sklearn.metrics import f1_score, precision_score, recall_score + +from opencompass.datasets.base import BaseDataset +from opencompass.openicl import BaseEvaluator +from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS +from opencompass.utils import get_data_path + + +@LOAD_DATASET.register_module() +class PEER_Dataset(BaseDataset): + + @staticmethod + def load(path, task, max_cut=-1, mini_set=False, hf_hub=False): + + # if (hf_hub is True): + # # load from huggingface hub + # train_data = [] + # repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1] + # train_path = train_path.split(repo_id + '/')[1] + # test_path = test_path.split(repo_id + '/')[1] + # + # train_path = hf_hub_download(repo_id, + # train_path, + # repo_type='dataset') + # test_path = hf_hub_download(repo_id, + # test_path, + # repo_type='dataset') + + path = get_data_path(path) + train_path = os.path.join(path, f'{task}/dev/data.json') + test_path = os.path.join(path, f'{task}/test/data.json') + + with open(train_path, 'r', encoding='utf-8') as f: + train_data = json.load(f) + with open(test_path, 'r', encoding='utf-8') as f: + test_data = json.load(f) + + train_data = train_data[:5] + # Limit the dataset to 5 samples for testing purposes + + if (max_cut != -1): + test_data = test_data[:max_cut] + if mini_set: + import random + random.seed(1024) + test_data = random.sample(test_data, 150) + random.seed() + + dataset = DatasetDict({ + 'train': Dataset.from_list(train_data), + 'test': Dataset.from_list(test_data) + }) + return dataset + + 
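+# Illustrative examples (hypothetical model outputs) for the
+# postprocessors defined below:
+#
+#   PEER_postprocess('The answer is Yes.')             # -> 'Yes'
+#   PEER_postprocess('The protein is not soluble.')    # -> 'No'
+#   PEER_postprocess_float_compare('score: 0.8', 0.5)  # -> 'Yes'
+
+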
+@TEXT_POSTPROCESSORS.register_module()
+def PEER_postprocess_default(text: Union[str, None]) -> str:
+    if not isinstance(text, str):
+        return ''
+    text = text.strip()
+    text = re.sub(r'<\|endoftext\|>', '', text)
+    text = re.sub(r'<\|im_end\|>', '', text)
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    return text
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def PEER_postprocess(text: Union[str, None]) -> str:
+    """
+    Extract the predicted label (Yes or No) from the raw model output.
+
+    The function first looks for the Yes/No that follows "The answer is",
+    then falls back to recognizing common Yes/No phrasings in the text.
+    """
+    # Type-check the input for robustness
+    if not isinstance(text, str):
+        return ''
+    # First check for an explicit "The answer is Yes/No" pattern
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+    match = re.search(r'The answer is\s+(Yes|No)', text, re.IGNORECASE)
+    if match:
+        return match.group(1)
+
+    # Common affirmative phrasings
+    positive_patterns = [
+        r'will be soluble',
+        r'will dissolve',
+        r'is soluble',
+        r'can be predicted',
+        r'positive',
+        r'Yes',
+        r'correct',
+        r'valid',
+        r'accurate',
+        r'certainly',
+        r'indeed',
+        r'affirmative',
+        r'highly soluble',
+        r'easily soluble',
+        r'dissolves easily',
+        r'is assured',
+        # r'likely',
+        r'be soluble'
+    ]
+
+    # Common negative phrasings
+    negative_patterns = [
+        r'will not be soluble',
+        r'is not soluble',
+        r'will not dissolve',
+        r'low solubility',
+        r'low',
+        r'cannot be predicted',
+        r'negative',
+        r'No',
+        r'incorrect',
+        r'invalid',
+        r'inaccurate',
+        r'impossible',
+        r'not possible',
+        r'denied',
+        r'be insoluble',
+    ]
+
+    # Check for affirmative phrasings
+    for pattern in positive_patterns:
+        if re.search(pattern, text, re.IGNORECASE):
+            return 'Yes'
+
+    # Check for negative phrasings
+    for pattern in negative_patterns:
+        if re.search(pattern, text, re.IGNORECASE):
+            return 'No'
+
+    # Nothing recognized; return an empty string
+    return ''
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def PEER_postprocess_float_compare(text: Union[str, None],
+                                   compare_number: float) -> str:
+    # Extract the predicted number from the model output and compare it
+    # with compare_number: return "Yes" if it is larger, otherwise "No".
+    if not isinstance(text, str):
+        return ''
+    try:
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        text = re.sub(r'.*?assistantfinal\s*', '', text, flags=re.DOTALL)
+        # Extract the number from the text
+        match = re.search(r'[-+]?\d*\.\d+|\d+', text)
+        if match:
+            value = float(match.group(0))
+            # Compare the values
+            if value > compare_number:
+                return 'Yes'
+            else:
+                return 'No'
+        else:
+            # No number found; return an empty string
+            return ''
+    except ValueError:
+        # Conversion failed; return an empty string
+        return ''
+
+
+def calculate_accuracy(pred_text_list, gold_text_list):
+    assert len(pred_text_list) == len(gold_text_list)
+    num_all = len(pred_text_list)
+    metrics = {}
+    metrics['num_all'] = num_all
+    num_no_answer = 0
+    num_invalid = 0
+    num_correct = 0
+    new_pred_text_list, new_gold_text_list = [], []
+    for (pred_item, gold_item) in zip(pred_text_list, gold_text_list):
+        if pred_item is None or pred_item == '':
+            num_no_answer += 1
+            continue
+        assert len(pred_item) == 1
+        assert len(gold_item) == 1
+        pred_item = pred_item[0].strip().lower()
+        gold_item = gold_item[0].strip().lower()
+        if pred_item == '':
+            num_no_answer += 1
+            continue
+        if pred_item not in ('yes', 'no'):
+            num_invalid += 1
+            continue
+        pred_item = 1 if pred_item == 'yes' else 0
+        gold_item = 1 if gold_item == 'yes' else 0
+        new_pred_text_list.append(pred_item)
+        new_gold_text_list.append(gold_item)
+        if gold_item == pred_item:
+            num_correct += 1
+
+    metrics['num_no_answer'] = num_no_answer
+    metrics['num_invalid'] = num_invalid
+    metrics['num_correct'] = num_correct
+
+    new_gold_text_list = np.array(new_gold_text_list)
+    new_pred_text_list = np.array(new_pred_text_list)
+
+    # macro_roc_auc_score = \
+    #     roc_auc_score(new_gold_text_list, new_pred_text_list)
+    f1 = f1_score(new_gold_text_list, new_pred_text_list)
+    # metrics['roc_auc_score'] = macro_roc_auc_score
+    metrics['accuracy'] = num_correct / (num_all) * 100
+    metrics['acc_wo_no_answer_invalid'] = num_correct / (
+        num_all - num_no_answer - num_invalid) * 100 if (
+            num_all - num_no_answer - num_invalid) > 0 else 0
+    metrics['precision'] = precision_score(new_gold_text_list,
+                                           new_pred_text_list) * 100
+    metrics['recall'] = recall_score(new_gold_text_list,
+                                     new_pred_text_list) * 100
+    metrics['f1_score'] = f1 * 100
+
+    return metrics
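+
+
+# Illustrative example (hypothetical values): with golds [['yes'], ['no']]
+# and predictions [['yes'], ['yes']], calculate_accuracy reports
+# num_correct == 1, accuracy == 50.0 and recall == 100.0 (the single gold
+# 'yes' was recovered), while precision is 50.0.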
+# ----------------------------------------------------------------------
+# Evaluator definition - this is the core of this change
+# ----------------------------------------------------------------------
+
+MAX_RETRIES = 3
+BACKOFF_SEC = 2
+
+
+class PEER_Evaluator(BaseEvaluator):
+
+    def __init__(self,
+                 task='solubility',
+                 gpt_model='gpt-4',
+                 openai_key='xxx',
+                 use_gpt=True,
+                 max_workers=8,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.task = task
+        self.gpt_model = gpt_model
+        self.use_gpt = use_gpt
+        self.max_workers = max_workers
+
+        if task in [
+                'stability',
+        ]:
+            self.use_gpt = False
+
+        if self.use_gpt:
+            if not openai_key:
+                raise ValueError('OpenAI API key is missing.')
+            self.client = OpenAI(base_url='url', api_key=openai_key)
+
+    def _retry_api(self, fn, *args, **kwargs):
+        last_exc = None
+        for attempt in range(1, MAX_RETRIES + 1):
+            try:
+                result = fn(*args, **kwargs)
+                if result is not None:
+                    return result
+                raise ValueError('Received None')
+            except Exception as e:
+                last_exc = e
+                sleep_time = BACKOFF_SEC**attempt
+                print(f'[retry] attempt {attempt} failed ({e}),'
+                      f' retrying in {sleep_time}s…')
+                time.sleep(sleep_time)
+        raise last_exc
+
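The retry helper sleeps BACKOFF_SEC ** attempt seconds between tries, i.e. 2 s then 4 s with the constants above. A self-contained sketch of how it behaves with a flaky callable (illustrative only):

calls = {'n': 0}

def flaky():
    # Fails twice, then succeeds; _retry_api retries with 2 s / 4 s backoff.
    calls['n'] += 1
    if calls['n'] < 3:
        raise RuntimeError('transient error')
    return 'ok'

# evaluator._retry_api(flaky)  ->  'ok' on the third attempt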
+    def ask_gpt25(self, question, answer, prediction):
+
+        prompt = (
+            'Please determine whether this answer is correct. Definition: '
+            "'Correct': the core conclusion of the model's answer (if any) "
+            'is completely consistent with the reference answer (literal '
+            "identity is not required). 'Incorrect': the core conclusion of "
+            "the model's answer is inconsistent with the reference answer, "
+            'or the core conclusion is not clearly expressed.\n'
+            f'Reference answer: {answer}\n'
+            f'Model answer: {prediction}\n'
+            "If correct, answer 'True'; if incorrect, answer 'False'. "
+            "Please only answer 'True' or 'False'.")
+
+        def _call():
+            response = self.client.chat.completions.create(
+                model=self.gpt_model,
+                messages=[{
+                    'role': 'user',
+                    'content': prompt
+                }],
+                temperature=0)
+
+            result = response.choices[0].message.content.strip().upper()
+            print('=== GPT judgement ===')
+            print(f'Prompt:\n{prompt}')
+            print(f'Output:\n{result}')
+            return result
+
+        try:
+            return self._retry_api(_call)
+        except Exception as e:
+            print(f'[GPT ERROR] Exception: {e}')
+            return ''
+
+    def ask_gpt25_batch(self, questions, answers, predictions):
+        results = [None] * len(questions)
+
+        def task(index, q, a, p):
+            try:
+                result = self.ask_gpt25(q, a, p)
+                results[index] = result
+            except Exception as e:
+                results[index] = ''
+                print(f'[GPT ERROR] batch sample {index} failed: {e}')
+
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = [
+                executor.submit(task, i, q, a, p)
+                for i, (q, a,
+                        p) in enumerate(zip(questions, answers, predictions))
+            ]
+            for future in as_completed(futures):
+                pass
+
+        return results
+
+    def score(self, predictions, references):
+        if len(predictions) != len(references):
+            return {
+                'error': 'predictions and references have different length'
+            }
+
+        if not isinstance(predictions[0], list):
+            predictions = [[pred] for pred in predictions]
+        if not isinstance(references[0], list):
+            references = [[ref] for ref in references]
+
+        postprocessed_references = [[PEER_postprocess(r[0]).strip().lower()]
+                                    for r in references]
+        postprocessed_predictions = [[PEER_postprocess(p[0]).strip().lower()]
+                                     for p in predictions]
+
+        voted_prediction = []
+        for pred in postprocessed_predictions:
+            valid_pred = [p for p in pred if p in ['yes', 'no']]
+            cnt = valid_pred.count('yes')
+            if cnt > len(valid_pred) / 2:
+                voted = 'yes'
+            elif cnt < len(valid_pred) / 2:
+                voted = 'no'
+            else:
+                voted = ''
+            voted_prediction.append([voted])
+
+        num_all = len(voted_prediction)
+        num_correct, num_no_answer, num_invalid = 0, 0, 0
+        num_gpt_called = 0
+        new_pred, new_gold = [], []
+
+        to_recheck_indices = []
+        to_recheck_golds = []
+        to_recheck_preds = []
+
+        for i, (pred_item, gold_item) in enumerate(
+                zip(postprocessed_predictions, postprocessed_references)):
+            pred = pred_item[0]
+            gold = gold_item[0]
+
+            if pred not in ('yes', 'no'):
+                to_recheck_indices.append(i)
+                to_recheck_golds.append(references[i][0])
+                to_recheck_preds.append(predictions[i][0])
+                continue
+
+            if pred == 'yes':
+                pred_bin = 1
+            elif pred == 'no':
+                pred_bin = 0
+            else:
+                to_recheck_indices.append(i)
+                to_recheck_golds.append(references[i][0])
+                to_recheck_preds.append(predictions[i][0])
+                continue
+
+            if gold == 'yes':
+                gold_bin = 1
+            elif gold == 'no':
+                gold_bin = 0
+            else:
+                to_recheck_indices.append(i)
+                to_recheck_golds.append(references[i][0])
+                to_recheck_preds.append(predictions[i][0])
+                continue
+
+            if pred_bin == gold_bin:
+                num_correct += 1
+                # import pdb; pdb.set_trace()
+                print(references[i][0], '\n', predictions[i][0], '----')
+                new_pred.append(pred_bin)
+                new_gold.append(gold_bin)
+            else:
+                to_recheck_indices.append(i)
+                to_recheck_golds.append(references[i][0])
+                to_recheck_preds.append(predictions[i][0])
+
+        if to_recheck_indices and self.use_gpt:
+            rechecked_preds = self.ask_gpt25_batch(
+                ['' for _ in to_recheck_indices], to_recheck_golds,
+                to_recheck_preds)
+            num_gpt_called += len(rechecked_preds)
+
+            for i, result in enumerate(rechecked_preds):
+                result = result.strip().lower()
+                if 'true' in result:
+                    num_correct += 1
+                    pred_bin = 1
+                    gold_bin = 1
+                elif 'false' in result:
+                    pred_bin = 0
+                    gold_bin = 1
+                else:
+                    pred_bin = 1
+                    gold_bin = 0
+
+                new_pred.append(pred_bin)
+                new_gold.append(gold_bin)
+
+        new_pred = np.array(new_pred)
+        new_gold = np.array(new_gold)
+
+        metrics = {
+            'num_all':
+            num_all,
+            'num_correct':
+            num_correct,
+            'num_no_answer':
+            num_no_answer,
+            'num_invalid':
+            num_invalid,
+            'num_gpt_called':
+            num_gpt_called,
+            'accuracy':
+            num_correct / num_all * 100,
+            'acc_wo_no_answer_invalid':
+            num_correct / (num_all - num_no_answer - num_invalid) * 100 if
+            (num_all - num_no_answer - num_invalid) > 0 else 0,
+            'precision':
+            precision_score(new_gold, new_pred, zero_division=0) * 100,
+            'recall':
+            recall_score(new_gold, new_pred, zero_division=0) * 100,
+            'f1_score':
+            f1_score(new_gold, new_pred, zero_division=0) * 100,
+        }
+
+        return metrics
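The voting step above generalizes to multiple sampled answers per item; with a single answer it is a pass-through. A standalone sketch of the same rule, for illustration:

def vote(preds):
    # Majority vote over valid 'yes'/'no' answers; ties yield ''.
    valid = [p for p in preds if p in ('yes', 'no')]
    yes = valid.count('yes')
    if yes * 2 > len(valid):
        return 'yes'
    if yes * 2 < len(valid):
        return 'no'
    return ''

assert vote(['yes', 'yes', 'no']) == 'yes'
assert vote(['no', 'garbage']) == 'no'  # invalid entries are ignored
assert vote(['yes', 'no']) == ''        # tie -> treated as unanswered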
diff --git a/opencompass/datasets/SciReasoner/__init__.py b/opencompass/datasets/SciReasoner/__init__.py
new file mode 100644
index 000000000..4383f6b19
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/__init__.py
@@ -0,0 +1,13 @@
+from .bio_instruction import *  # noqa: F401, F403
+from .bulk_modulus_material import *  # noqa: F401, F403
+from .composition_material import *  # noqa: F401, F403
+from .GUE import *  # noqa: F401, F403
+from .LLM4Chem import *  # noqa: F401, F403
+from .LLM4Mat import *  # noqa: F401, F403
+from .Mol_Instructions import *  # noqa: F401, F403
+from .opi import *  # noqa: F401, F403
+from .PEER import *  # noqa: F401, F403
+from .uncond_material import *  # noqa: F401, F403
+from .uncond_RNA import *  # noqa: F401, F403
+from .unconditional_molecule_generation import *  # noqa: F401, F403
+from .unconditional_protein_generation import *  # noqa: F401, F403
diff --git a/opencompass/datasets/SciReasoner/bio_instruction.py b/opencompass/datasets/SciReasoner/bio_instruction.py
new file mode 100644
index 000000000..a163f5fc8
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/bio_instruction.py
@@ -0,0 +1,1440 @@
+# flake8: noqa
+
+import json
+import os
+import re
+import sys
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+import torch
+from datasets import Dataset, DatasetDict
+from huggingface_hub import hf_hub_download
+
+try:
+    from scipy.stats import pearsonr, spearmanr
+except Exception:
+    pearsonr, spearmanr = None, None
+from sklearn.metrics import (accuracy_score, matthews_corrcoef,
+                             mean_absolute_error, mean_squared_error,
+                             precision_score, recall_score, roc_auc_score)
+from tqdm import tqdm
+from transformers import pipeline
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.utils import get_data_path
+
+# current_working_directory = os.getcwd()
+# path_bioinstruction = os.path.join(current_working_directory, 'OpenCompass_SciReasoner_extra_data',
+#                                    'datasets', 'bioinstruction')
+
+
+# @LOAD_DATASET.register_module()
+class Bioinstruction_Dataset(BaseDataset):
+
+    @staticmethod
+    def load(path, task, mini_set=False, hf_hub=False):
+        # if (hf_hub is True):
+        #     # load from huggingface hub
+        #     train_data = []
+        #     repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1]
+        #     train_path = train_path.split(repo_id + '/')[1]
+        #     test_path = test_path.split(repo_id + '/')[1]
+        #     train_path = hf_hub_download(repo_id,
+        #                                  train_path,
+        #                                  repo_type='dataset')
+        #     test_path = hf_hub_download(repo_id,
+        #                                 test_path,
+        #                                 repo_type='dataset')
+
+        path = get_data_path(path)
+        train_path = os.path.join(path, f'{task}/dev/data.json')
+        test_path = os.path.join(path, f'{task}/test/data.json')
+        with open(train_path, 'r', encoding='utf-8') as f:
+            train_data = json.load(f)
+        train_data = train_data[:5]
+        with open(test_path, 'r', encoding='utf-8') as f:
+            test_data = json.load(f)
+
+        selected_train_data = [{
+            'input': record['input'],
+            'output': record['output']
+        } for record in train_data]
+        selected_test_data = [{
+            'input': record['input'],
+            'output': record['output']
+        } for record in test_data]
+        # dataset=Dataset.from_list(selected_train_data)
+        if mini_set and len(selected_test_data) > 150:
+            import random
+            random.seed(1024)
+            selected_test_data = random.sample(selected_test_data, 150)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(selected_train_data),
+            'test': Dataset.from_list(selected_test_data)
+        })
+        return dataset
+
+
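Note the seeding pattern shared by these loaders: seeding before sampling makes the 150-item mini set reproducible across runs, and re-seeding from entropy afterwards avoids pinning global randomness. A self-contained illustration:

import random

data = list(range(1000))
random.seed(1024)
subset_a = random.sample(data, 150)
random.seed(1024)
subset_b = random.sample(data, 150)
assert subset_a == subset_b  # identical mini set on every run
random.seed()                # restore nondeterministic global state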
+def extract_answer_part(outputs, left_tag, right_tag, mode='tag'):
+    assert mode in ('tag', 'direct')
+
+    assert isinstance(outputs, list)
+    answers = []
+    for text in outputs:
+        if mode == 'direct' or (left_tag is None and right_tag is None):
+            text = text.replace('<think>', '').replace('</think>',
+                                                       '').strip()
+            answers.append(text.strip())
+            continue
+
+        left_tag_pos = text.find(left_tag)
+        if left_tag_pos == -1:
+            answers.append('')
+            continue
+        right_tag_pos = text.find(right_tag)
+        if right_tag_pos == -1:
+            answers.append('')
+            continue
+        text = text[left_tag_pos + len(left_tag):right_tag_pos].strip()
+        answers.append(text)
+    return answers
+
+
+def extract_numeric_values(text):
+    text = text.replace("5'", "five'")
+    text = text.replace("3'", 'three')
+
+    # The original extraction pattern was corrupted when this diff was
+    # rendered; a standard signed int/float pattern is assumed here, and
+    # the matches are returned as floats as the callers below expect.
+    matches = re.findall(r'[-+]?\d*\.\d+|[-+]?\d+', text)
+    return [float(m) for m in matches]
+
+
+# Use the sentiment analysis model as fallback
+# if classification by keywords fails
+def classify_by_sentiment_model(text):
+    text = [
+        str(t).replace('<think>', '').replace('</think>', '').strip()
+        for t in text
+    ]
+
+    candidate_labels = [
+        'Yes,I can positively identify', 'No,My answer is negative',
+        'This protein is expected to dissolve in water',
+        'This protein is not expected to dissolve in water'
+    ]
+
+    classifier = pipeline('zero-shot-classification',
+                          model='facebook/bart-large-mnli',
+                          device=0)
+
+    outputs = classifier(text, candidate_labels, batch_size=64)
+    processed_results = []
+    for output in outputs:
+        # The HF zero-shot pipeline returns labels sorted by score,
+        # highest first
+        top_label = output['labels'][0]
+        top_score = output['scores'][0]
+
+        if (top_label == 'Yes,I can positively identify' or top_label
+                == 'This protein is expected to dissolve in water'):
+            result_class = 1
+        else:
+            result_class = 0
+
+        processed_results.append((result_class, top_score))
+    return processed_results
+
+
+def classify_by_keywords(text):
+    positive_keywords = [
+        'Yes', 'yes', 'positive', 'Positive', 'empirical', 'plausible',
+        'confirms', 'have detected', 'are discernible', 'are supported',
+        'is supported', 'display', 'detected the presence', 'shows evidence',
+        'has been identified', 'shows', 'has identified', 'contains ',
+        'exhibits evidence', 'is plausible', 'contains identifiable', 'Indeed',
+        'reveals the presence', 'include', 'are present', 'definitely has',
+        'soluble', 'displays regions', 'has a high solubility',
+        'dissolves easily', 'Solubility is expected',
+        'is expected to dissolve', 'is predicted', 'is likely', 'is expected',
+        'is expected to dissolve', 'will dissolve', 'dissolves easily'
+    ]
+
+    negative_keywords = [
+        'No', 'no', 'negative', 'Negative', 'insoluble', 'does not',
+        'unlikely', 'absence', 'not found', 'not detected', 'not associated',
+        'not inferred', 'not linked', 'does not indicate', 'no evidence',
+        'not predicted', 'absent', 'not present', 'no indicators',
+        'not exhibit', 'are absent', 'found none', 'did not reveal', 'lacks',
+        'exhibits no', 'insolubility', 'low solubility', 'not soluble',
+        'not be soluble', 'does not display regions', 'cannot confirm'
+    ]
+
+    dont_know_keywords = [
+        'don\'t know', 'unknown', 'unsure', 'uncertain', 'not applicable',
+        'cannot confirm'
+    ]
+
+    text_lower = text.lower()
+
+    # Escape any special characters in the keywords and join them with '|';
+    # \b ensures whole-word matches
+    negative_pattern = r'\b(' + '|'.join(
+        re.escape(kw) for kw in negative_keywords) + r')\b'
+    positive_pattern = r'\b(' + '|'.join(
+        re.escape(kw) for kw in positive_keywords) + r')\b'
+    dont_know_pattern = r'\b(' + '|'.join(
+        re.escape(kw) for kw in dont_know_keywords) + r')\b'
+
+    # 1. Check the negative keywords first
+    if re.search(negative_pattern, text_lower):
+        return 0
+    # 2. Then the positive keywords
+    elif re.search(positive_pattern, text_lower):
+        return 1
+    # 3. Finally the "don't know" keywords
+    elif re.search(dont_know_pattern, text_lower):
+        return 'dont_know'
+    else:
+        return None
+
+
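Illustrative behaviour of the keyword classifier above (negatives are checked first, so phrases like "not soluble" are not shadowed by the positive "soluble" keyword):

assert classify_by_keywords('The protein is not soluble in water.') == 0
assert classify_by_keywords('Yes, a binding motif is present.') == 1
assert classify_by_keywords('The result is uncertain.') == 'dont_know'
assert classify_by_keywords('ACGUACGU') is None  # defer to the sentiment model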
+# Save the processed data for each task in a separate file
+# def save_processed_data(model_name, task_name, task_processed_data):
+#
+#     dir_path = path_bioinstruction + f'/processed_data/{model_name}'
+#     file_path = f'{dir_path}/{task_name}_processed_data.json'
+#     os.makedirs(dir_path, exist_ok=True)
+#     with open(file_path, 'w') as outfile:
+#         json.dump(task_processed_data, outfile, indent=4)
+#
+#     print(f'Task {task_name} processed data saved in {file_path}')
+
+
+# Process regression task
+def process_regression_task(task_name, task_entries, model_name):
+    result_values = []
+    label_values = []
+    task_processed_data = []
+    over_len = 0
+    miss_len = 0
+    for index, entry in enumerate(task_entries):
+        # print(entry)
+        # Keep only the text after the reasoning block; an unclosed
+        # <think> tag means the generation was truncated.
+        if '</think>' in entry['model_output']:
+            entry['model_output'] = entry['model_output'].split(
+                '</think>')[-1]
+            extracted_result = extract_numeric_values(entry['model_output'])
+        else:
+            if '<think>' in entry['model_output']:
+                over_len += 1
+                extracted_result = []
+            else:
+                miss_len += 1
+                extracted_result = extract_numeric_values(
+                    entry['model_output'])
+
+        label = float(entry['label'])
+        print('label', label)
+        print('extracted_result', extracted_result)
+
+        if len(extracted_result
+               ) != 0 and extracted_result[0] > 80 and task_name == 'Isoform':
+            print(entry['model_output'])
+            extracted_result = []
+
+        if len(extracted_result) != 1:
+            print('not one:', entry['model_output'])
+            extracted_result = []
+
+        if len(extracted_result) == 0:
+            result_values.append(
+                np.inf)  # Assign infinity if no valid result is extracted
+        else:
+            result_values.append(
+                extracted_result[0])  # Take the first valid extracted result
+
+        label_values.append(label)
+
+        task_processed_data.append({
+            'input':
+            entry['input'],
+            'label':
+            entry['label'],
+            'processed_model_output':
+            extracted_result[0] if len(extracted_result) > 0 else np.inf,
+            'original_model_output':
+            entry['model_output'],
+        })
+
+    # save_processed_data(model_name, task_name, task_processed_data)
+    print('over_len: ', over_len)
+    print('miss_len: ', miss_len)
+    return label_values, result_values
+
+
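Given the reconstruction of extract_numeric_values above (the original pattern was mangled in this diff), the intended behaviour on a typical regression answer looks like:

# The 5'/3' replacements keep primer notation from being parsed as numbers.
text = "The 5' UTR efficiency is 0.82 and the raw count is 120"
# extract_numeric_values(text)  ->  [0.82, 120.0]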
+# Compute spearman correlation
+def compute_spearman(label_values, result_values):
+    if len(result_values) == 0:
+        return {'spearman': 'Error: Empty data'}
+    elif len(result_values) != len(label_values):
+        return {
+            'spearman':
+            'Error: Mismatch in the number of extracted numeric values'
+        }
+
+    # Convert the label and result values to numpy arrays
+    result_values = np.array(result_values).flatten()
+    label_values = np.array(label_values).flatten()
+
+    # Identify explicitly assigned infinity values
+    near_infinity_mask = np.isinf(result_values)
+
+    # Exclude near-infinity pairs from the main calculation
+    valid_mask = ~near_infinity_mask & np.isfinite(
+        result_values) & np.isfinite(label_values)
+    valid_result_values = result_values[valid_mask]
+    valid_label_values = label_values[valid_mask]
+
+    outlier_mask = valid_result_values <= 300
+
+    valid_result_values = valid_result_values[outlier_mask]
+    valid_label_values = valid_label_values[outlier_mask]
+
+    # Initialize the metrics
+    spearman = 0.0
+    rmse = 0.0
+
+    # Compute Spearman correlation for valid values
+    if len(valid_result_values) > 0:
+        spearman, _ = spearmanr(valid_label_values, valid_result_values)
+        mse = mean_squared_error(valid_label_values, valid_result_values)
+        # Take the square root to get the RMSE
+        rmse = np.sqrt(mse)
+
+    else:
+        spearman = 0  # Fallback if no valid pairs
+
+    total_data_points = len(result_values)
+    total_valid_points = valid_mask.sum()
+    num_infinity_values = near_infinity_mask.sum()
+
+    if num_infinity_values > 0:
+        final_spearman_score = (spearman * total_valid_points +
+                                0 * num_infinity_values) / total_data_points
+    else:
+        final_spearman_score = spearman  # Edge case: no near-infinity values
+    print('rmse:', rmse)
+
+    return {'spearman': final_spearman_score}
+
+
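The final score above down-weights by coverage: samples with no parseable number contribute 0, so in effect final = rho_valid * n_valid / n_total. A quick numeric check (illustrative counts):

# 100 samples, 20 unparseable (scored 0), Spearman 0.6 on the valid 80:
spearman_valid, n_valid, n_total = 0.6, 80, 100
final_score = (spearman_valid * n_valid + 0 * (n_total - n_valid)) / n_total
assert abs(final_score - 0.48) < 1e-9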
+# Compute R2
+def compute_R2(label_values, result_values):
+    # from sklearn.metrics import r2_score
+
+    # y_true = np.asarray(label_values, dtype=float).flatten()
+    # y_pred = np.asarray(result_values, dtype=float).flatten()
+
+    # Check for empty data
+    if len(result_values) == 0:
+        return {'R2': 'Error: Empty data.'}
+
+    # Check for equal length of arrays
+    elif len(result_values) != len(label_values):
+        return {
+            'R2': 'Error: Mismatch in the number of extracted numeric values.'
+        }
+
+    # Convert the label and result values to numpy arrays
+    result_values = np.array(result_values).flatten()
+    label_values = np.array(label_values).flatten()
+
+    # Identify explicitly assigned infinity values
+    near_infinity_mask = np.isinf(result_values)
+
+    # Exclude near-infinity pairs from the main calculation
+    valid_mask = ~near_infinity_mask & np.isfinite(
+        result_values) & np.isfinite(label_values)
+    valid_result_values = result_values[valid_mask]
+    valid_label_values = label_values[valid_mask]
+
+    # Compute Pearson correlation coefficient for valid values
+    if len(valid_result_values) > 0:
+        try:
+            pcc, _ = pearsonr(valid_label_values, valid_result_values)
+            R2 = pcc**2
+            # mse = mean_squared_error(valid_label_values, valid_result_values)
+            # Take the square root to get the RMSE
+            # rmse = np.sqrt(mse)
+        except Exception:
+            R2 = np.inf  # Fallback to inf if computation fails
+    else:
+        R2 = 0  # Fallback if no valid pairs
+
+    # Combine R2 score for valid and infinity values
+    total_data_points = len(result_values)
+    total_valid_points = valid_mask.sum()
+    num_infinity_values = near_infinity_mask.sum()
+
+    if num_infinity_values > 0:
+        final_R2_score = (R2 * total_valid_points +
+                          0 * num_infinity_values) / total_data_points
+    else:
+        final_R2_score = R2  # Edge case: no near-infinity values
+    # print("RMSE:", rmse)
+    return {'R2': final_R2_score}
+
+
+# Compute mixed score
+def compute_mixed_score(label_values,
+                        result_values,
+                        threshold=30,
+                        max_value=1e3):
+    rmse = 0.0
+    if len(result_values) == 0:
+        return {'mixed_score': 'Error: Empty data.'}
+    elif len(result_values) != len(label_values):
+        return {
+            'mixed_score':
+            'Error: Mismatch in the number of extracted numeric values'
+        }
+
+    # Convert the label and result values to numeric arrays
+    # using pandas to handle non-numeric entries
+    result_values = pd.to_numeric(result_values, errors='coerce').flatten()
+    label_values = pd.to_numeric(label_values, errors='coerce').flatten()
+
+    # Identify near-infinity values
+    near_infinity_mask = np.abs(result_values) > max_value
+    if near_infinity_mask.any():
+        print(
+            f'Warning: Found {sum(near_infinity_mask)} result values too large'
+            ' that will be assigned a mixed score of 0. 
' + f'Large result values: {result_values[near_infinity_mask]} ') + + # Exclude near-infinity pairs from the main calculation + valid_mask = ~near_infinity_mask & np.isfinite( + result_values) & np.isfinite(label_values) + valid_result_values = result_values[valid_mask] + valid_label_values = label_values[valid_mask] + + # Assign a mixed score of 0 to near-infinity pairs + num_infinity_values = near_infinity_mask.sum() + if num_infinity_values > 0: + mixed_score_infinity = 0 + + # Convert to binary based on the threshold for valid values + label_binary = (valid_label_values < threshold).astype(int) + result_binary = (valid_result_values < threshold).astype(int) + + # Compute precision, recall, F1 score for valid values + precision = precision_score(label_binary, result_binary, average='binary') + recall = recall_score(label_binary, result_binary, average='binary') + f1 = 2 * precision * recall / (precision + recall) if (precision + + recall) != 0 else 0 + + try: + # Compute mean absolute error (MAE) for valid values + mae = mean_absolute_error(valid_label_values, valid_result_values) + mse = mean_squared_error(valid_label_values, valid_result_values) + rmse = np.sqrt(mse) + + except ValueError: + mae = np.inf # Fallback to infinity if error occurs + + # Mask to keep only values in the range [0, threshold] for valid values + mask = (valid_result_values >= 0) & (valid_result_values <= threshold) + if mask.sum() > 0: + range_mae = mean_absolute_error(valid_label_values[mask], + valid_result_values[mask]) + else: + range_mae = 100 # Fallback if no values within the range + + # Ensure MAE and range_mae are within reasonable bounds to avoid overflow + mae = min(mae, 100) + range_mae = min(range_mae, 100) + + # Compute mixed score for valid values + mixed_score_valid = (1 - mae / 100) * 0.5 + (1 - + range_mae / 100) * f1 * 0.5 + print( + f'(1 - mae / 100) * 0.5={(1 - mae / 100) * 0.5}\n ' + f'(1 - range_mae / 100)={(1 - range_mae / 100)}\n ' + f'(1 - range_mae / 100) * f1 * 0.5={(1 - range_mae / 100) * f1 * 0.5}') + + # Compute the final mixed score, + # averaging in the score for the near-infinity pairs + total_data_points = len(result_values) + total_valid_points = valid_mask.sum() + + if num_infinity_values > 0: + final_mixed_score = ( + mixed_score_valid * total_valid_points + + mixed_score_infinity * num_infinity_values) / total_data_points + else: + # Edge case: no near-infinity values + final_mixed_score = mixed_score_valid + print('RMSE', rmse) + + return {'mixed_score': final_mixed_score} + + +# Programmable Switch task: +# multilabel regression output one average correlation +def compute_R2_for_ProgrammableRNASwitches_task(task_name, task_entries, + model_name): + on_result_values = [] + off_result_values = [] + on_off_result_values = [] + + on_label_values = [] + off_label_values = [] + on_off_label_values = [] + + task_processed_data = [] + over_len = 0 + miss_len = 0 + # Loop through each entry in the task + for entry in task_entries: + label = entry['label'] + # label = ast.literal_eval(label) + on_label = float(label['ON']) + off_label = float(label['OFF']) + on_off_label = float(label['ON_OFF']) + + # Extract numeric values from the model output + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split('')[-1] + else: + if '' in entry['model_output']: + over_len += 1 + else: + miss_len += 1 + extracted_result = extract_numeric_values(entry['model_output']) + print('extracted_result', extracted_result) + + # Handle missing or invalid data by assigning 
np.nan + if len(extracted_result) != 3: + on_result_values.append(np.nan) + off_result_values.append(np.nan) + on_off_result_values.append(np.nan) + else: + on_result = extracted_result[0] + off_result = extracted_result[1] + on_off_result = extracted_result[2] + on_result_values.append(on_result) + off_result_values.append(off_result) + on_off_result_values.append(on_off_result) + + # Append the label values + on_label_values.append(on_label) + off_label_values.append(off_label) + on_off_label_values.append(on_off_label) + + # Save processed task data for this entry + task_processed_data.append({ + 'input': + entry['input'], + 'label': + entry['label'], + 'processed_model_output': { + 'ON': on_result if len(extracted_result) == 3 else np.nan, + 'OFF': off_result if len(extracted_result) == 3 else np.nan, + 'ON_Off': + on_off_result if len(extracted_result) == 3 else np.nan + }, + 'original_model_output': + entry['model_output'] + }) + + # Save the processed task data + # save_processed_data(model_name, task_name, task_processed_data) + + # Convert to numpy arrays for easier manipulation + on_result_values = np.array(on_result_values) + off_result_values = np.array(off_result_values) + on_off_result_values = np.array(on_off_result_values) + + on_label_values = np.array(on_label_values) + off_label_values = np.array(off_label_values) + on_off_label_values = np.array(on_off_label_values) + + # Filter out NaN values in ON, OFF, and ON/OFF result/label pairs + on_valid_mask = np.isfinite(on_result_values) & np.isfinite( + on_label_values) + off_valid_mask = np.isfinite(off_result_values) & np.isfinite( + off_label_values) + on_off_valid_mask = np.isfinite(on_off_result_values) & np.isfinite( + on_off_label_values) + + # Filter the valid ON, OFF, and ON/OFF values + on_result_values = on_result_values[on_valid_mask] + off_result_values = off_result_values[off_valid_mask] + on_off_result_values = on_off_result_values[on_off_valid_mask] + + on_label_values = on_label_values[on_valid_mask] + off_label_values = off_label_values[off_valid_mask] + on_off_label_values = on_off_label_values[on_off_valid_mask] + + try: + on_R2 = compute_R2( + on_result_values, + on_label_values)['R2'] if len(on_result_values) > 0 else 0 + except Exception: + on_R2 = 0 # Assign 0 in case of error + + try: + off_R2 = compute_R2( + off_result_values, + off_label_values)['R2'] if len(off_result_values) > 0 else 0 + except Exception: + off_R2 = 0 # Assign 0 in case of error + + try: + on_off_R2 = compute_R2( + on_off_result_values, + on_off_label_values)['R2'] if len(on_off_result_values) > 0 else 0 + except Exception: + on_off_R2 = 0 # Assign 0 in case of error + + # Combine R2 scores for ON, OFF, and ON/OFF values + total_on_points = max(len(on_result_values) + np.sum(~on_valid_mask), 1) + total_off_points = max(len(off_result_values) + np.sum(~off_valid_mask), 1) + total_on_off_points = max( + len(on_off_result_values) + np.sum(~on_off_valid_mask), 1) + + # Assign average R2 with 0 for invalid entries + final_on_R2 = (on_R2 * len(on_result_values)) / total_on_points if len( + on_result_values) > 0 else 0 + final_off_R2 = (off_R2 * len(off_result_values)) / total_off_points if len( + off_result_values) > 0 else 0 + final_on_off_R2 = (on_off_R2 * + len(on_off_result_values)) / total_on_off_points if len( + on_off_result_values) > 0 else 0 + + avg_R2 = (final_on_R2 + final_off_R2 + final_on_off_R2) / 3 + print('over_len: ', over_len) + print('miss_len: ', miss_len) + print('123', final_on_R2, final_off_R2, final_on_off_R2) + 
return {'R2': avg_R2} + + +# Enhancer Activity Task: +# multilabel regression output two individual correlation +def compute_PCC_for_enhancer_activity_task(task_name, task_entries, + model_name): + hk_result_values = [] + dev_result_values = [] + + hk_label_values = [] + dev_label_values = [] + + task_processed_data = [] + over_len = 0 + miss_len = 0 + # Loop through each entry in the task + for entry in task_entries: + label = entry['label'] + # label = ast.literal_eval(label) + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split('')[-1] + else: + if '' in entry['model_output']: + over_len += 1 + else: + miss_len += 1 + model_output = entry['model_output'] + print('model_output', model_output) + hk_label = float(label['hk']) + dev_label = float(label['dev']) + + # Extract model output values for HK and Dev enhancer activity + extracted_result = extract_numeric_values(model_output) + + # Handle missing or invalid data by assigning np.inf + if len(extracted_result) != 2: + + hk_result_values.append(np.inf) + dev_result_values.append(np.inf) + else: + hk_result = extracted_result[0] + dev_result = extracted_result[1] + hk_result_values.append(hk_result) + dev_result_values.append(dev_result) + + # Append the label values + hk_label_values.append(hk_label) + dev_label_values.append(dev_label) + + # Save processed task data for this entry + task_processed_data.append({ + 'input': + entry['input'], + 'label': + entry['label'], + 'processed_model_output': { + 'hk': hk_result if len(extracted_result) == 2 else np.inf, + 'dev': dev_result if len(extracted_result) == 2 else np.inf + }, + 'original_model_output': + entry['model_output'] + }) + + # Save the processed task data + # save_processed_data(model_name, task_name, task_processed_data) + + # Convert to numpy arrays for easier manipulation + hk_result_values = np.array(hk_result_values) + dev_result_values = np.array(dev_result_values) + hk_label_values = np.array(hk_label_values) + dev_label_values = np.array(dev_label_values) + + # Filter out NaN or inf values in both HK and Dev result/label pairs + hk_valid_mask = np.isfinite(hk_result_values) & np.isfinite( + hk_label_values) + dev_valid_mask = np.isfinite(dev_result_values) & np.isfinite( + dev_label_values) + + # Filter the valid HK and Dev values + hk_result_values = hk_result_values[hk_valid_mask] + hk_label_values = hk_label_values[hk_valid_mask] + dev_result_values = dev_result_values[dev_valid_mask] + dev_label_values = dev_label_values[dev_valid_mask] + + # Compute Pearson correlation for valid HK and Dev enhancer activities + if len(hk_result_values) > 0: + try: + hk_pcc, _ = pearsonr(hk_result_values, hk_label_values) + except Exception: + hk_pcc = np.inf # Set to inf in case of errors + else: + return { + 'PCC': + 'Error: HK has insufficient valid data ' + 'after removing NaNs and infs.' + } + if len(dev_result_values) > 0: + try: + dev_pcc, _ = pearsonr(dev_result_values, dev_label_values) + except Exception: + dev_pcc = np.inf # Set to inf in case of errors + else: + return { + 'PCC': + 'Error: Dev has insufficient valid data ' + 'after removing NaNs and infs.' 
+        }
+
+    # Combine results with NaN/inf values consideration
+    total_hk_points = len(hk_result_values) + np.sum(~hk_valid_mask)
+    total_dev_points = len(dev_result_values) + np.sum(~dev_valid_mask)
+
+    # Assign mixed score with 0 for invalid entries
+    final_hk_pcc = (hk_pcc * len(hk_result_values) + 0 * np.sum(~hk_valid_mask)
+                    ) / total_hk_points if len(hk_result_values) > 0 else 0
+    final_dev_pcc = (dev_pcc * len(dev_result_values) +
+                     0 * np.sum(~dev_valid_mask)) / total_dev_points if len(
+                         dev_result_values) > 0 else 0
+    print('over_len:', over_len)
+    print('miss_len: ', miss_len)
+    return {
+        'PCC': (final_hk_pcc + final_dev_pcc) / 2,
+        'hk_PCC': final_hk_pcc,
+        'dev_PCC': final_dev_pcc
+    }
+
+
+# Process binary classification task
+def process_binary_classification_task(task_name, task_entries, model_name):
+    label_classes = []
+    result_classes = []
+    task_processed_data = []
+    entries_for_model = []
+    over_len = 0
+    miss_len = 0
+    for index, entry in enumerate(tqdm(task_entries)):
+        # Keep only the text after the reasoning block; an unclosed
+        # <think> tag means the generation was truncated.
+        if '</think>' in entry['model_output']:
+            entry['model_output'] = entry['model_output'].split(
+                '</think>')[-1]
+        else:
+            if '<think>' in entry['model_output']:
+                over_len += 1
+            else:
+                miss_len += 1
+
+        label_class = 1 if entry['label'] == 'positive' else 0
+        model_output = str(entry['model_output'])
+        result_class = None
+        score = 0
+
+        if model_output is None:
+            result_class = 1 - label_class
+        else:
+            keyword_result = classify_by_keywords(model_output)
+            if keyword_result == 'dont_know':
+                result_class = 1 - label_class
+            elif keyword_result is not None:
+                result_class = keyword_result
+            else:
+                if model_output and model_output.strip():
+                    entries_for_model.append({
+                        'index': index,
+                        'text': model_output
+                    })
+                else:
+                    result_class = 1 - label_class
+
+        # Store the finished entries first, leaving empty slots for the
+        # results produced by the fallback model
+        task_processed_data.append({
+            'input': entry['input'],
+            'original_label': entry['label'],
+            'processed_label': label_class,
+            'original_model_output': model_output,
+            'processed_model_output': result_class,  # may be None; filled below
+            'score': 'N/A'  # defaults to N/A
+        })
+    print(len(entries_for_model))
+
+    if entries_for_model:
+
+        texts_to_classify = [item['text'] for item in entries_for_model]
+
+        # Pass all texts to the model in one batch
+        model_results = classify_by_sentiment_model(texts_to_classify)
+
+        for i, model_item in enumerate(tqdm(entries_for_model)):
+            original_index = model_item['index']
+            result_class, score = model_results[i]
+
+            # (optional) if the confidence is low, count the sample as wrong
+            # if score < 0.5:
+            #     result_class =
+            #     1 - task_processed_data[original_index]['processed_label']
+
+            # Write the model results back to the correct positions in the
+            # final data list
+            task_processed_data[original_index][
+                'processed_model_output'] = result_class
+            task_processed_data[original_index]['score'] = str(score)
+
+    result_classes = [d['processed_model_output'] for d in task_processed_data]
+    label_classes = [d['processed_label'] for d in task_processed_data]
+    print('miss_len:', miss_len)
+    print('over_len:', over_len)
+
+    # save_processed_data(model_name, task_name, task_processed_data)
+
+    return label_classes, result_classes
+
+
+# Compute matthews correlation coefficient (MCC)
+def compute_MCC(label_classes, result_classes):
+    if len(result_classes) == 0:
+        return {'MCC': 'Error: Empty data.'}
+    elif len(result_classes) != len(label_classes):
+        return {
+            'MCC': 'Error: Mismatch in the number of extracted numeric values.'
+ } + else: + mcc = matthews_corrcoef(label_classes, result_classes) + return {'MCC': mcc} + + +# Compute accuracy score (Acc) +def compute_Acc(label_classes, result_classes): + if len(result_classes) == 0: + return { + 'Acc': + 'Error: Insufficient data for classification. ' + 'Number of model outputs is 0.' + } + elif len(result_classes) != len(label_classes): + return { + 'Acc': + 'Error: Mismatched labels. ' + 'The number of model outputs does not match the number of labels.' + } + else: + acc = accuracy_score(label_classes, result_classes) + return {'Acc': acc} + + +# Extract RNA family from the text +def extract_rna_family(text): + for rna_class in RNA_CLASSES: + if rna_class in text: + return rna_class + return None + + +# Compute ACC metric for NoncodingRNAFamily multiclass classification task +def compute_Acc_for_NoncodingRNAFamily_task(task_name, task_entries, + model_name): + correct_count = 0 + total_count = 0 + task_processed_data = [] + over_len = 0 + miss_len = 0 + for entry in task_entries: + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split('')[-1] + result_family = extract_rna_family(entry['model_output']) + else: + if '' in entry['model_output']: + over_len += 1 + else: + miss_len += 1 + # result_family = "None" + result_family = extract_rna_family(entry['model_output']) + + label_family = entry['label'] + # result_family = extract_rna_family(entry["model_output"]) + # Compare extracted family with the ground truth label + if result_family == label_family: + correct_count += 1 + + total_count += 1 + + # Store original and processed data + task_processed_data.append({ + 'input': + entry['input'], + 'label': + entry['label'], + 'processed_model_output': + result_family, + 'original_model_output': + entry['model_output'] + }) + + # save_processed_data(model_name, task_name, task_processed_data) + print('over_len:', over_len) + print('miss_len:', miss_len) + # Calculate accuracy + accuracy = correct_count / total_count if total_count > 0 else 0 + if (total_count - over_len) != 0: + print('true_acc:', correct_count / (total_count - over_len)) + + return {'Acc': accuracy} + + +# Extract RNA modification labels from the output text +def extract_modifications(text): + extracted_modifications = [] + for mod_class in modification_classes: + # Use word boundaries to ensure whole-word match + if re.search(rf'\b{mod_class}\b', text): + extracted_modifications.append(mod_class) + return extracted_modifications + + +# Convert modification labels to a binary multihot vector +def convert_to_binary_vector(modifications, classes=modification_classes): + binary_vector = [] + + # Handle case where modifications is None + if modifications is None: + modifications = [] # Treat None as an empty list + + for mod in classes: + if mod in modifications: + binary_vector.append(1) + else: + binary_vector.append(0) + return binary_vector + + +# Compute AUC metrics for Modification task +def compute_AUC_for_Modification_task(task_name, task_entries, model_name): + y_true = [] + y_pred = [] + task_processed_data = [] + over_len = 0 + miss_len = 0 + for entry in task_entries: + # MARK:gaile + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split( + '')[-1] + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split('')[-1] + else: + if '' in entry['model_output']: + over_len += 1 + else: + miss_len += 1 + predicted_modifications = extract_modifications(entry['model_output']) + # print(predicted_modifications) 
+ true_modifications = entry['label'].split(',') + + # Handle case where result is empty and label is "none" + if not predicted_modifications: + # Classify by keyword + predicted_modifications = classify_by_keywords( + entry['model_output']) + + # If keyword negative, + # assigned to prediction to be the "none" class + if predicted_modifications == 0: + predicted_modifications = ['none'] + + elif predicted_modifications == 1: + predicted_modifications = [] + + # If the result cannot be classified, use the sentiment model + elif predicted_modifications is None: + + sentiment_result, sentiment_score = \ + classify_by_sentiment_model( + [entry['model_output']])[0] + + # If classified as negative, manually label as 'none' + if sentiment_result == 0: + predicted_modifications = ['none'] + + else: + predicted_modifications = [] + + # Convert the predicted and true modifications to binary vectors + y_true.append(convert_to_binary_vector(true_modifications)) + y_pred.append(convert_to_binary_vector(predicted_modifications)) + + # Store the processed data + task_processed_data.append({ + 'input': + entry['input'], + 'label': + entry['label'], + 'processed_model_ouput': + predicted_modifications, + 'original_model_output': + entry['model_output'] + }) + print('label', entry['label']) + print('predication', predicted_modifications) + + # save_processed_data(model_name, task_name, task_processed_data) + print('over_len:', over_len) + print('miss_len: ', miss_len) + # Compute the AUC for each class, then average the AUC across all classes + try: + auc = roc_auc_score(y_true, y_pred, average='macro') + print('auc', auc) + except ValueError: + auc = None + + return {'AUC': auc} + + +# FunctionEC Task +# Modified from +# SaProt https://github.com/westlake-repl/SaProt/blob/main/utils/metrics.py +def count_f1_max(pred, target): + """ + F1 score with the optimal threshold. + Handles cases where either predictions or targets are empty. + + Parameters: + pred (Tensor): predictions of shape :math:`(B, N)` + target (Tensor): binary targets of shape :math:`(B, N)` + + Returns: + float: The maximum F1 score or 0.0 if inputs are empty. 
+ """ + # Check if either pred or target is empty + if pred.numel() == 0 or target.numel() == 0: + return 0.0 + + # Proceed with the original logic if inputs are not empty + order = pred.argsort(descending=True, dim=1, stable=True) + # print(f"order: {order}") + target = target.gather(1, order) + precision = target.cumsum(1) / torch.ones_like(target).cumsum(1) + recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10) + + is_start = torch.zeros_like(target).bool() + is_start[:, 0] = 1 + is_start = torch.scatter(is_start, 1, order, is_start) + all_order = pred.flatten().argsort(descending=True, stable=True) + order = order + torch.arange( + order.shape[0], device=order.device).unsqueeze(1) * order.shape[1] + order = order.flatten() + inv_order = torch.zeros_like(order) + inv_order[order] = torch.arange(order.shape[0], device=order.device) + is_start = is_start.flatten()[all_order] + all_order = inv_order[all_order] + + precision = precision.flatten() + recall = recall.flatten() + + all_precision = precision[all_order] - \ + torch.where( + is_start, torch.zeros_like(precision), + precision[all_order - 1]) + all_precision = all_precision.cumsum(0) / is_start.cumsum(0) + all_recall = recall[all_order] - \ + torch.where( + is_start, torch.zeros_like(recall), + recall[all_order - 1]) + all_recall = all_recall.cumsum(0) / pred.shape[0] + all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + + 1e-10) + + if torch.isnan(all_f1).any(): + return 0.0 + + return all_f1.max() + + +def round_and_scale_results(data, decimal_places=3, scale_factor=100): + for key, value in data.items(): + if isinstance(value, dict): + # Recursive call if the value is a dictionary + round_and_scale_results(value, decimal_places, scale_factor) + elif isinstance(value, (float, int)): + # Round and scale numeric values + data[key] = float(round(value * scale_factor, decimal_places)) + + +# Convert EC number to binary multihot vectors +def ec_to_multihot(ec_list, ec_labels): + multihot = torch.zeros(len(ec_labels)) + if not ec_list: # Check if ec_list is empty + return multihot + multihot = torch.zeros(len(ec_labels)) + for ec in ec_list: + if ec in ec_labels: + idx = ec_labels.index(ec) + multihot[idx] = 1 + return multihot + + +# Compute Fmax metric for FunctionEC task +def compute_Fmax_for_FunctionEC_task(task_name, task_entries, ec_labels, + model_name): + all_preds = [] + all_labels = [] + task_processed_data = [] + over_len = 0 + miss_len = 0 + for entry in task_entries: + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split('')[-1] + else: + if '' in entry['model_output']: + over_len += 1 + else: + miss_len += 1 + if '' in entry['model_output']: + entry['model_output'] = entry['model_output'].split( + '')[-1] + # Parse the EC numbers from 'output' and 'label' + label_ec = re.findall(r'\d+\.\d+\.\d+\.\-?\d*', entry['label']) + result_ec = re.findall(r'\d+\.\d+\.\d+\.\-?\d*', + str(entry['model_output'])) + + # Convert EC numbers to multi-hot vectors + pred_multihot = ec_to_multihot(result_ec, ec_labels) + label_multihot = ec_to_multihot(label_ec, ec_labels) + + # Store the results + all_preds.append(pred_multihot) + all_labels.append(label_multihot) + + # Save processed task data + task_processed_data.append({ + 'input': + entry['input'], + 'label': + entry['label'], + 'processed_label': + label_ec, + 'original_model_output': + entry['model_output'], + 'processed_model_output': + result_ec, + }) + print('label_ec', label_ec) + print('result_ec', result_ec) + + # 
save_processed_data(model_name, task_name, task_processed_data) + + # # Stack the predictions and targets for batch processing + all_preds = torch.stack(all_preds) + all_labels = torch.stack(all_labels) + print('miss_len: ', miss_len) + print('over_len: ', over_len) + # Compute the Fmax score + try: + fmax_score = count_f1_max(all_preds, all_labels) + except ValueError: + fmax_score = None + + return {'Fmax': fmax_score.item()} + + +def preprocess_input_data(input_file_path, prediction, mini_set=False): + data = [] + # Open the input file and process each line + + with open(input_file_path, 'r') as f: + data_in = json.load(f) + if mini_set and len(data_in) > 150: + import random + random.seed(1024) + data_in = random.sample(data_in, 150) + random.seed() + + if len(prediction) == len(data_in): + for index in range(len(data_in)): + try: + data_list = {} + data_list['input'] = data_in[index]['input'] + data_list['output'] = data_in[index]['output'] + # Try to load the line as a JSON object + + data_list['model_output'] = prediction[index] + data_list['label'] = data_in[index]['label'] + # data_list['label']=data_in[index]['label'] + + data_list['task'] = data_in[index]['task'] + # data_list['task']=data_in[index]['task'] + data.append(data_list) + # Ensure the parsed data is a dictionary + except json.JSONDecodeError: + print(f'Skipping invalid line: {data_in[index]}') + else: + print('len(prediction)!=len(data_in) !!!') + + df = pd.DataFrame(data) # Convert to a DataFrame + # df = pd.read_json(input_file_path, lines=True, encoding_errors="ignore") + print(f'Number of data samples: {len(df)}') + df.rename(columns={'result': 'model_output'}, inplace=True) + print(df['task']) + df['task'] = df['task'].replace('rna_protein_interaction', + 'ncRNAProteinInter') + df['task'] = df['task'].replace('antibody_antigen', 'AntibodyAntigen') + # Process entries with null labels + # null_label_df = df[df['label'].isna()] + # # null_label_df.to_json(f"{model_name}_result_label_null.json", + # orient='records', lines=True) + + # Remove data for _all task + # df = df[~df['task'].str.endswith('_all')] + + # Replace 'tf-h' with 'tf_h' and 'tf-m' with 'tf_m' in the 'task' column + df['task'] = df['task'].str.replace('tf-h', 'tf_h') + df['task'] = df['task'].str.replace('tf-m', 'tf_m') + + # Keep data if label is not null + df = df[df['label'].notna()] + df.reset_index(inplace=True, drop=True) + + # Convert to dictionary format for grouping + data = df.to_dict(orient='records') + + # Group the data by 'task' + grouped_data = defaultdict(list) + for entry in data: + task_name = entry['task'].split('-')[0] + grouped_data[task_name].append(entry) + + return grouped_data + + +class bio_instruction_Evaluator(BaseEvaluator): + + def __init__(self, + path, + task, + model_name, + mini_set=False, + *args, + **kwargs): + super().__init__(*args, **kwargs) + + path = get_data_path(path) + self.dataset_path = os.path.join(path, f'{task}/test/data.json') + self.model_name = model_name + self.mini_set = mini_set + + def score(self, predictions): + test_path = self.dataset_path + repo_id = '/'.join(test_path.split('/')[:-3]) + ec_path = 'ec_labels.json' + ec_file_path = os.path.join(repo_id, ec_path) + # ec_file_path = hf_hub_download(repo_id, ec_path, repo_type="dataset") + + with open(ec_file_path, 'r') as f: + ec_labels = json.load(f) + + test_path = test_path.split(repo_id + '/')[1] + input_file_path = self.dataset_path + # input_file_path = + # hf_hub_download(repo_id, test_path, repo_type="dataset") + + grouped_data = 
preprocess_input_data(input_file_path, + predictions, + mini_set=self.mini_set) + + print(f'Grouped data for tasks: {list(grouped_data.keys())}') + + register_tasks_path = 'register_tasks.json' + register_tasks_file_path = os.path.join(repo_id, register_tasks_path) + # register_tasks_file_path = + # hf_hub_download(repo_id, register_tasks_path, repo_type="dataset") + with open(register_tasks_file_path, 'r') as f: + task_type_data = json.load(f) + + metrics = {} + + # Loop over tasks + for task_name, task_entries in grouped_data.items(): + task_type = task_type_data[task_name]['type'] + task_metrics = task_type_data[task_name]['metrics'] + print(f'Prosessing {task_name} task...') + print(task_type) + sys.stdout.flush() + + if task_type == 'regression': + # task_processed_data, label_values, result_values + # = process_regression_task(task_name, task_entries) + label_values, result_values = process_regression_task( + task_name, task_entries, self.model_name) + if task_metrics == 'spearman': + metrics[task_name] = compute_spearman( + label_values, result_values) + + elif task_metrics == 'R2': + metrics[task_name] = compute_R2(label_values, + result_values) + # print(metrics[task_name]) + + elif task_metrics == 'mixed_score': + metrics[task_name] = compute_mixed_score(label_values, + result_values, + threshold=30) + + elif task_type == 'binary classification': + # task_processed_data, label_classes, result_classes + # = process_binary_classification_task(task_name, task_entries) + label_classes, result_classes = \ + process_binary_classification_task( + task_name, task_entries, self.model_name) + print(f'label_classes: {label_classes}') + print(f'result_classes: {result_classes}') + if task_metrics == 'MCC': + metrics[task_name] = compute_MCC(label_classes, + result_classes) + + elif task_metrics == 'Acc': + metrics[task_name] = compute_Acc(label_classes, + result_classes) + + elif task_type == 'multilabel regression': + + if task_name == 'ProgrammableRNASwitches': + metrics[task_name] = \ + compute_R2_for_ProgrammableRNASwitches_task( + task_name, task_entries, self.model_name) + + elif task_name == 'enhancer_activity': + metrics[ + task_name] = compute_PCC_for_enhancer_activity_task( + task_name, task_entries, self.model_name) + + elif task_type == 'multiclass classification': + + if task_name == 'NoncodingRNAFamily': + metrics[ + task_name] = compute_Acc_for_NoncodingRNAFamily_task( + task_name, task_entries, self.model_name) + + elif task_type == 'multilabel classification': + if task_name == 'FunctionEC': + metrics[task_name] = compute_Fmax_for_FunctionEC_task( + task_name, task_entries, ec_labels, self.model_name) + + elif task_name == 'Modification': + metrics[task_name] = compute_AUC_for_Modification_task( + task_name, task_entries, self.model_name) + + print(f'The metrics {task_metrics} for task {task_name}' + f' is {str(metrics[task_name][task_metrics])}') + sys.stdout.flush() + + metrics_grouped_by_omics = defaultdict(dict) + + for task_name, task_metrics in metrics.items(): + # Get the omics type from task_type_data + omics = task_type_data[task_name]['omics'] + + # Scale the metrics + scaled_metrics = task_metrics.copy( + ) # Make a copy to avoid modifying the original + round_and_scale_results( + scaled_metrics) # Apply scaling to the metrics + + # Add the scaled metrics to the grouped dictionary + metrics_grouped_by_omics[omics][task_name] = scaled_metrics + + # Save the metrics (results) to a new JSON file + # metrics_file_path = ( + # path_bioinstruction + 
f'/metrics_result/{omics}/' +
+        #     f'metrics_result_{self.model_name}_{task_name}.json')
+        # output_directory = os.path.dirname(metrics_file_path)
+        # os.makedirs(output_directory, exist_ok=True)
+        # with open(metrics_file_path, 'w') as outfile:
+        #     json.dump(metrics_grouped_by_omics[omics], outfile, indent=4)
+        # print(f'Metrics saved to {metrics_file_path}')
+
+        return metrics_grouped_by_omics[omics][task_name]
diff --git a/opencompass/datasets/SciReasoner/bulk_modulus_material.py b/opencompass/datasets/SciReasoner/bulk_modulus_material.py
new file mode 100644
index 000000000..bbc0f0384
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/bulk_modulus_material.py
@@ -0,0 +1,173 @@
+# flake8: noqa
+
+import json
+import os
+import re
+from collections import Counter
+from typing import List, Union
+
+from datasets import Dataset, DatasetDict
+from huggingface_hub import hf_hub_download
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+from opencompass.utils import get_data_path
+
+
+@LOAD_DATASET.register_module()
+class Bulk_modulus_material_Dataset(BaseDataset):
+
+    @staticmethod
+    def load(path, mini_set=False):
+        # if (hf_hub is True):
+        #     # load from huggingface hub
+        #     train_data = []
+        #     repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1]
+        #     train_path = train_path.split(repo_id + '/')[1]
+        #     test_path = test_path.split(repo_id + '/')[1]
+        #
+        #     train_path = hf_hub_download(repo_id,
+        #                                  train_path,
+        #                                  repo_type='dataset')
+        #     test_path = hf_hub_download(repo_id,
+        #                                 test_path,
+        #                                 repo_type='dataset')
+
+        path = get_data_path(path)
+        train_path = os.path.join(path, 'bulk_modulus_material/dev/data.json')
+        test_path = os.path.join(path, 'bulk_modulus_material/test/data.json')
+
+        # load from local json file
+        with open(train_path, 'r', encoding='utf-8') as f:
+            train_data = json.load(f)
+        with open(test_path, 'r', encoding='utf-8') as f:
+            test_data = json.load(f)
+
+        train_data = train_data[:5]
+        # Limit the dataset to 5 samples for testing purposes
+        if mini_set:
+            import random
+            random.seed(1024)
+            test_data = random.sample(test_data, 150)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def material_postprocessor(text: Union[str, None]) -> str:
+    """Extract the content of the <material>...</material> block."""
+    # The literal tag was stripped when this diff was rendered;
+    # the tag name below is reconstructed.
+    if not text:
+        return ''
+    match = re.search(r'<material>(.*?)</material>', text,
+                      re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+    return ''
+
+
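Assuming the <material>...</material> tag reconstruction above (the literal tag was stripped in this diff), the postprocessor behaves like:

out = 'Sure. <material>Ba Ti O3</material> is a stable perovskite.'
# material_postprocessor(out)        ->  'Ba Ti O3'
# material_postprocessor('no tags')  ->  ''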
+class material_Evaluator(BaseEvaluator):
+    """
+    Evaluator for:
+    - SMACT validity
+    - Composition precision (based on output-extracted elements)
+    - Exact match (between prediction and reference block)
+    """
+
+    def __init__(self, data_path=None, **kwargs):
+        super().__init__()
+        self.data_path = os.path.join(get_data_path(data_path),
+                                      'bulk_modulus_material/test/data.json')
+        self.prompt_elements_list = []  # elements extracted from the gt
+        self.reference_materials = []  # reference answers for exact match
+
+        if self.data_path:
+            self._load_ground_truths()
+
+    def _load_ground_truths(self):
+        """Load the ground-truth elements and materials."""
+        with open(self.data_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        for item in data:
+            output = item.get('output', '')
+            # Extract the constituent elements
+            elements = re.findall(r'\b[A-Z][a-z]?\b',
+                                  material_postprocessor(output))
+            self.prompt_elements_list.append(elements)
+            # Extract the full material block for exact match
+            self.reference_materials.append(material_postprocessor(output))
+
+    def _normalize(self, formula: str) -> str:
+        """Normalize a formula (sort element symbols, keep their counts)."""
+        tokens = re.findall(r'([A-Z][a-z]?)(\d*)', formula)
+        tokens.sort(key=lambda x: x[0])
+        return ''.join(f"{el}{cnt or ''}" for el, cnt in tokens)
+
+    def score(self, predictions: List[dict]):
+
+        from smact.screening import smact_validity
+
+        total = len(predictions)
+        format_valid = 0
+        smact_valid = 0
+        precision_sum = 0.0
+        exact_match_count = 0
+
+        for i, item in enumerate(predictions):
+            if isinstance(item, str):
+                item = {'prediction': item}
+            text = item.get('prediction', '').strip()
+
+            # --- SMACT validity ---
+            # The original pattern lost its trailing tag when this diff was
+            # rendered; match the leading run of element symbols instead.
+            match = re.match(r'([A-Z][a-z]?(?: [A-Z][a-z]?)*)', text)
+            if match:
+                elements_str = match.group(1)
+                elements = elements_str.split()
+                counter = Counter(elements)
+                formula = ''.join(f"{el}{cnt or ''}"
+                                  for el, cnt in sorted(counter.items()))
+                try:
+                    if smact_validity(formula):
+                        smact_valid += 1
+                    format_valid += 1
+                except Exception:
+                    pass
+
+            # --- Composition precision ---
+            if i < len(self.prompt_elements_list):
+                gt_elements = set(self.prompt_elements_list[i])
+                pred_elements = set(re.findall(r'\b[A-Z][a-z]?\b', text))
+                correct = len(gt_elements & pred_elements)
+                if gt_elements:
+                    precision_sum += correct / len(gt_elements)
+
+            # --- Exact Match ---
+            if i < len(self.reference_materials):
+                pred_mat = material_postprocessor(text)
+                gt_mat = self.reference_materials[i]
+                if pred_mat == gt_mat:
+                    exact_match_count += 1
+
+        avg_precision = (precision_sum / total * 100) if total else 0.0
+        smact_in_format = (smact_valid / format_valid *
+                           100) if format_valid else 0.0
+        smact_in_all = (smact_valid / total * 100) if total else 0.0
+        exact_match_ratio = (exact_match_count / total * 100) if total else 0.0
+
+        return {
+            'total_samples': total,
+            'format_valid_count': format_valid,
+            'smact_valid_count': smact_valid,
+            'smact_validity_ratio_in_format_valid_%': smact_in_format,
+            'smact_validity_ratio_in_all_%': smact_in_all,
+            'average_precision_%': avg_precision,
+            'exact_match_ratio_%': exact_match_ratio
+        }
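_normalize gives an order-insensitive comparison key by sorting element tokens alphabetically while keeping their counts, e.g.:

# Both orderings of the same formula normalize to the same key.
# evaluator._normalize('TiO2')  ->  'O2Ti'
# evaluator._normalize('O2Ti')  ->  'O2Ti'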
+    Supported patterns include:
+    - composed of
+    - that has
+    - characterized by
+    - with the composition
+    - based on
+    - featuring
+    - whose makeup is
+    """
+    patterns = [
+        r'composed of', r'that has', r'characterized by',
+        r'with the composition', r'based on', r'featuring', r'whose makeup is'
+    ]
+
+    joined = '|'.join(patterns)
+    match = re.search(rf'(?:{joined})\s+(.*?)(?:[\.。\n]|$)', prompt,
+                      re.IGNORECASE)
+
+    if match:
+        elements_str = match.group(1)
+        elements = [
+            el.strip() for el in re.split(r'[,\s]+', elements_str)
+            if re.fullmatch(r'[A-Z][a-z]?', el.strip())
+        ]
+        return elements
+
+    # Fallback: try to extract every plausible element symbol
+    fallback = re.findall(r'\b[A-Z][a-z]?\b', prompt)
+    return fallback
+
+
+def composition_precision(elements: list[str], prediction: str) -> float:
+    """Compute the fraction of prompt elements hit by the prediction."""
+    E_pi = set(elements)
+    clean = re.sub(r'<[^>]+>', ' ', prediction)
+    E_gi = set(re.findall(r'\b[A-Z][a-z]?\b', clean))
+    if not E_pi:
+        return 0.0
+    return len(E_pi & E_gi) / len(E_pi)
+
+
+@LOAD_DATASET.register_module()
+class Composition_material_Dataset(BaseDataset):
+
+    @staticmethod
+    def load(path, mini_set=False):
+
+        # if (hf_hub is True):
+        #     # load from huggingface hub
+        #     train_data = []
+        #     repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1]
+        #     train_path = train_path.split(repo_id + '/')[1]
+        #     test_path = test_path.split(repo_id + '/')[1]
+        #
+        #     train_path = hf_hub_download(repo_id,
+        #                                  train_path,
+        #                                  repo_type='dataset')
+        #     test_path = hf_hub_download(repo_id,
+        #                                 test_path,
+        #                                 repo_type='dataset')
+
+        path = get_data_path(path)
+        train_path = os.path.join(
+            path, 'conditional_generation/composition_material/dev/data.json')
+        test_path = os.path.join(
+            path,
+            'conditional_generation/composition_material/test/data.json')
+
+        # load from local json file
+        with open(train_path, 'r', encoding='utf-8') as f:
+            train_data = json.load(f)
+        with open(test_path, 'r', encoding='utf-8') as f:
+            test_data = json.load(f)
+
+        # Limit the train split to 5 samples
+        train_data = train_data[:5]
+        if mini_set:
+            import random
+            random.seed(1024)
+            test_data = random.sample(test_data, 150)
+            random.seed()
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def material_postprocessor(text: Union[str, None]) -> str:
+    """Extract the content inside <material> tags."""
+    if not text:
+        return ''
+    match = re.search(r'<material>(.*?)</material>', text,
+                      re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+    return ''
+
+
+class composition_Evaluator(BaseEvaluator):
+
+    def __init__(self, data_path, tuning_data=None, **kwargs):
+        super().__init__()
+        self.data_path = os.path.join(
+            get_data_path(data_path),
+            'conditional_generation/composition_material/test/data.json')
+        self.prompts = []
+        self.gt_materials = set()
+
+        if self.data_path:
+            self._load_original_inputs()
+
+    def _load_original_inputs(self):
+        with open(self.data_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        self.prompts = [item.get('input', '') for item in data]
+
+        for item in data:
+            output = item.get('output', '')
+            mat = material_postprocessor(output)
+            if mat:
+                self.gt_materials.add(mat.strip())
+
+    def _normalize(self, formula):
+        tokens = re.findall(r'([A-Z][a-z]?)(\d*)', formula)
+        tokens.sort(key=lambda x: x[0])
+        return ''.join(f"{el}{cnt or ''}" for el, cnt in tokens)
+
+    def score(self, predictions):
+        from smact.screening import smact_validity
+
+        total = len(predictions)
+        format_valid = 0
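+        # Note (editor's sketch): given a parsed element list such as
+        # ['Sr', 'Ti', 'O', 'O'], the Counter + sorted logic below builds
+        # the composition string 'O2Sr1Ti1'; strings with explicit 1s parse
+        # the same way 'O2SrTi' does when handed to smact_validity.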
+        smact_valid = 0
+        precision_sum = 0.0
+        novel_count = 0
+
+        for i, item in enumerate(predictions):
+            if isinstance(item, str):
+                item = {'prediction': item}
+
+            text = item.get('prediction', '').strip()
+            prompt = item.get('input', '').strip()
+            if not prompt and i < len(self.prompts):
+                prompt = self.prompts[i]
+
+            prompt_elements = extract_elements_from_prompt(prompt)
+
+            print('== Sample ==')
+            print('Prompt:', prompt)
+            print('Prompt Elements:', prompt_elements)
+            print('Prediction Text:', text[:200])
+
+            # --- SMACT validity ---
+            match = re.match(
+                r'([A-Z][a-z]?(?: [A-Z][a-z]?)*?)\s*(?:<sg>\s*<sg(\d+)>)?',
+                text)
+            if match:
+                elements_str, _ = match.groups()
+                elements = elements_str.split()
+                counter = Counter(elements)
+                formula = ''.join(f"{el}{cnt or ''}"
+                                  for el, cnt in sorted(counter.items()))
+                try:
+                    if smact_validity(formula):
+                        smact_valid += 1
+                    format_valid += 1
+                except Exception:
+                    pass
+
+            # --- Composition precision ---
+            if prompt_elements:
+                precision_sum += composition_precision(prompt_elements, text)
+
+            # --- Novelty ---
+            predicted_material = material_postprocessor(text)
+            if not predicted_material:
+                predicted_material = text.strip()
+
+            if predicted_material:
+                print(f'[Novelty Check] GT materials: {self.gt_materials}')
+                print(
+                    f'[Novelty Check] Predicted material: {predicted_material}'
+                )
+                if predicted_material not in self.gt_materials:
+                    novel_count += 1
+                    print('[Novelty] Novel')
+                else:
+                    print('[Novelty] Seen before')
+
+        avg_precision = (precision_sum / total * 100) if total else 0.0
+        smact_in_format = (smact_valid / format_valid *
+                           100) if format_valid else 0.0
+        smact_in_all = (smact_valid / total * 100) if total else 0.0
+        novelty_ratio = (novel_count / total * 100) if total else 0.0
+
+        return {
+            'total_samples': total,
+            'format_valid_count': format_valid,
+            'smact_valid_count': smact_valid,
+            'smact_validity_ratio_in_format_valid_%': smact_in_format,
+            'smact_validity_ratio_in_all_%': smact_in_all,
+            'average_precision_%': avg_precision,
+            'novel_material_ratio_%': novelty_ratio,
+        }
diff --git a/opencompass/datasets/SciReasoner/opi/__init__.py b/opencompass/datasets/SciReasoner/opi/__init__.py
new file mode 100644
index 000000000..a7eba4446
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/opi/__init__.py
@@ -0,0 +1,7 @@
+# flake8: noqa
+from .config import TASKS as opi_TASKS  # noqa: F401, F403
+from .config import \
+    TASKS_GENERATION_SETTINGS as \
+    opi_TASKS_GENERATION_SETTINGS  # noqa: F401, F403
+from .evaluator import opi_postprocess  # noqa: F401, F403
+from .evaluator import Opi_Evaluator, OpiDataset
diff --git a/opencompass/datasets/SciReasoner/opi/config.py b/opencompass/datasets/SciReasoner/opi/config.py
new file mode 100644
index 000000000..2f42c689d
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/opi/config.py
@@ -0,0 +1,136 @@
+TASKS = (
+    'EC_number',
+    'Subcellular_localization',
+    'Fold_type',
+    'Keywords',
+    'GO',
+    'Function',
+    'gSymbol2Tissue',
+    'gSymbol2Cancer',
+    'gName2Cancer',
+)
+
+DEFAULT_MAX_INPUT_TOKENS = 512
+DEFAULT_MAX_NEW_TOKENS = 400
+
+# All nine tasks currently share the same sampling settings, so build the
+# mapping once instead of repeating the same literal for every task.
+TASKS_GENERATION_SETTINGS = {
+    task: {
+        'generation_kargs': {
+            'num_return_sequences': 1,
+            'num_beams': 1,
+            'temperature': 0.2,
+            'top_k': 50,
+            'top_p': 0.75,
+            'do_sample': True,
+        },
+    }
+    for task in TASKS
+}
+
+TASK_TAGS = {
+    'EC_number': ('', ''),
+    'Subcellular_localization': ('', ''),
+    'Fold_type': ('', ''),
+    'Keywords': ('', ''),
+    'GO': ('', ''),
+    'Function': ('', ''),
+    'gSymbol2Tissue': ('', ''),
+    'gSymbol2Cancer': ('', ''),
+    'gName2Cancer': ('', ''),
+}
+
+# These tasks output lists whose items may be separated by semicolons.
+# To facilitate evaluation, each semicolon is replaced by a dot.
+TASKS_WITH_SEMICOLON_REPLACE = ('Keywords', 'GO')
+
+# For these tasks, one input might have multiple gold answers,
+# so the gold answers should be read directly from the dataset
+# instead of using the gold field of each sample.
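+# (Editor's illustration, hypothetical data: one protein may carry gold EC
+# numbers ['1.1.1.1', '1.1.1.2']; comparing against a single per-sample
+# gold string would under-count such multi-answer inputs.)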
+TASKS_WITH_READING_GOLD_FROM_DATASET = TASKS + +BASE_MODELS = { + 'osunlp/LlaSMol-Mistral-7B': 'mistralai/Mistral-7B-v0.1', + 'osunlp/LlaSMol-Galactica-6.7B': 'facebook/galactica-6.7b', + 'osunlp/LlaSMol-Llama2-7B': 'meta-llama/Llama-2-7b-hf', + 'osunlp/LlaSMol-CodeLlama-7B': 'codellama/CodeLlama-7b-hf', +} diff --git a/opencompass/datasets/SciReasoner/opi/evaluator.py b/opencompass/datasets/SciReasoner/opi/evaluator.py new file mode 100644 index 000000000..a7b62a135 --- /dev/null +++ b/opencompass/datasets/SciReasoner/opi/evaluator.py @@ -0,0 +1,289 @@ +# flake8: noqa +# opencompass/datasets/opi/evaluator.py + +import json +import os +import re + +from datasets import Dataset, DatasetDict +from huggingface_hub import hf_hub_download + +from opencompass.datasets.base import BaseDataset +from opencompass.openicl import BaseEvaluator +from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS +from opencompass.utils import get_data_path + +from .utils.metrics4all import calculate_metrics, calculate_rouge_l + + +@LOAD_DATASET.register_module() +class OpiDataset(BaseDataset): + + @staticmethod + def load(path, task, max_cut=-1, mini_set=False, hf_hub=False): + + # if (hf_hub is True): + # # load from huggingface hub + # train_data = [] + # repo_id = test_path.split('/')[0] + '/' + test_path.split('/')[1] + # train_path = train_path.split(repo_id + '/')[1] + # test_path = test_path.split(repo_id + '/')[1] + # + # train_path = hf_hub_download(repo_id, + # train_path, + # repo_type='dataset') + # test_path = hf_hub_download(repo_id, + # test_path, + # repo_type='dataset') + + path = get_data_path(path) + train_path = os.path.join(path, f'{task}/dev/data.json') + test_path = os.path.join(path, f'{task}/test/data.json') + + with open(train_path, 'r', encoding='utf-8') as f: + train_data = json.load(f) + with open(test_path, 'r', encoding='utf-8') as f: + test_data = json.load(f) + + # train_data = train_data[:10] + # # Limit the dataset to 10 samples for testing purposes + # test_data = test_data[:10] + if mini_set: + import random + random.seed(1024) + test_data = random.sample(test_data, 50) + random.seed() + + dataset = DatasetDict({ + 'train': Dataset.from_list(train_data), + 'test': Dataset.from_list(test_data) + }) + return dataset + + +def extract_answer_part(outputs, left_tag, right_tag, mode='tag'): + assert mode in ('tag', 'direct') + assert isinstance(outputs, list) + + answers = [] + for text in outputs: + if mode == 'direct' or (left_tag is None and right_tag is None): + text = text.replace('', '').replace('', '').strip() + answers.append(text.strip()) + continue + + left_tag_pos = text.find(left_tag) + if left_tag_pos == -1: + answers.append('') + continue + right_tag_pos = text.find(right_tag) + if right_tag_pos == -1: + answers.append('') + continue + text = text[left_tag_pos + len(left_tag):right_tag_pos].strip() + answers.append(text) + return answers + + +@TEXT_POSTPROCESSORS.register_module('opi_postprocess') +def opi_postprocess(text, task, *args, **kwargs): + print(f'task: {task}, text: {text}') + text = text.strip() + text = re.sub(r'<\|endoftext\|>', '', text) + text = re.sub(r'<\|im_end\|>', '', text) + return text + + +class Opi_Evaluator(BaseEvaluator): + + def __init__(self, task, *args, **kwargs): + super().__init__(*args, **kwargs) + self.task = task + + def score(self, predictions, references): + if len(predictions) != len(references): + return { + 'error': 'predictions and references have different length' + } + + if not isinstance(predictions[0], list): + 
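+            # Normalization sketch: ['a', 'b'] -> [['a'], ['b']], so the
+            # per-task methods below can uniformly read pred_list[0].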
+            predictions = [[pred] for pred in predictions]
+        if not isinstance(references[0], list):
+            references = [[ref] for ref in references]
+
+        if self.task == 'Function':
+            return self._evaluate_function(predictions, references)
+        elif self.task == 'Subcellular_localization':
+            return self._evaluate_subcellular_localization(
+                predictions, references)
+        elif self.task == 'Fold_type':
+            return self._evaluate_fold_type(predictions, references)
+        elif self.task in ('EC_number', 'GO', 'Keywords', 'gSymbol2Tissue',
+                           'gSymbol2Cancer', 'gName2Cancer'):
+            return self._evaluate_multilabel(predictions, references)
+        else:
+            return self._evaluate_general(predictions, references)
+
+    def _evaluate_function(self, predictions, references):
+        """Evaluate the function-description task with ROUGE-L."""
+        # if not METRICS_AVAILABLE:
+        #     return self._evaluate_text_similarity(predictions, references)
+
+        rouge_ls = []
+        for pred_list, ref_list in zip(predictions, references):
+            pred = pred_list[0].strip()
+            ref = ref_list[0].strip()
+
+            # Ensure output and target are lists
+            if isinstance(pred, str):
+                pred = [pred]
+            if isinstance(ref, str):
+                ref = [ref]
+
+            rouge_l = calculate_rouge_l(pred, ref)
+            rouge_ls.append(rouge_l)
+
+        mean_rouge_l = sum(rouge_ls) / len(rouge_ls) if rouge_ls else 0
+        return {
+            'ROUGE-L': round(mean_rouge_l, 4),
+            # 'total': len(predictions)
+        }
+
+    def _evaluate_subcellular_localization(self, predictions, references):
+        """Evaluate the subcellular-localization task with accuracy."""
+        # if not METRICS_AVAILABLE:
+        #     return self._evaluate_general(predictions, references)
+
+        accuracies = []
+        for pred_list, ref_list in zip(predictions, references):
+            pred = pred_list[0].strip()
+            ref = ref_list[0].strip()
+
+            # Ensure output and target are lists
+            if isinstance(pred, str):
+                pred = [pred]
+            if isinstance(ref, str):
+                ref = [ref]
+
+            accuracy, _, _, _ = calculate_metrics(pred, ref)
+            accuracies.append(accuracy)
+
+        mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
+        return {
+            'Accuracy': round(mean_accuracy, 4),
+            # 'total': len(predictions)
+        }
+
+    def _evaluate_fold_type(self, predictions, references):
+        """Evaluate the fold-type task (same computation as accuracy4fold_type.py)."""
+        # Initialize counters
+        correct_predictions = 0
+        total_predictions = 0
+
+        # Evaluate each prediction
+        for pred_list, ref_list in zip(predictions, references):
+            pred = pred_list[0].strip()
+            ref = ref_list[0].strip()
+
+            # Compare prediction and ground truth directly
+            if pred == ref:
+                correct_predictions += 1
+            total_predictions += 1
+
+        # Compute accuracy
+        accuracy = correct_predictions / total_predictions \
+            if total_predictions > 0 else 0
+
+        return {
+            'Accuracy': round(accuracy, 4),
+            # 'correct': correct_predictions,
+            # 'total': total_predictions
+        }
+
+    def _evaluate_multilabel(self, predictions, references):
+        """Evaluate the multi-label tasks (EC_number, GO, Keywords, ...)."""
+        # if not METRICS_AVAILABLE:
+        #     return self._evaluate_general(predictions, references)
+
+        precisions = []
+        recalls = []
+        f1_scores = []
+
+        for pred_list, ref_list in zip(predictions, references):
+            pred = pred_list[0].strip()
+            ref = ref_list[0].strip()
+
+            # Split semicolon/comma-separated label strings and drop blanks
+            if isinstance(pred, str):
+                pred = [
+                    p.strip() for p in re.split(r'[;,,;]\s*', pred)
+                    if p.strip()
+                ]
+            if isinstance(ref, str):
+                ref = [
+                    r.strip() for r in re.split(r'[;,,;]\s*', ref)
+                    if r.strip()
+                ]
+
+            # Only score when the reference labels are non-empty
+            if ref:
+                _, precision, recall, f1 = calculate_metrics(pred, ref)
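+                # Micro-averaged toy example (editor's sketch): with
+                # pred ['A', 'B'] and ref ['A', 'C'], calculate_metrics
+                # binarizes over {A, B, C} and returns precision 0.5,
+                # recall 0.5 and f1 0.5 for this sample.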
+                precisions.append(precision)
+                recalls.append(recall)
+                f1_scores.append(f1)
+
+        mean_precision = sum(precisions) / len(precisions) if precisions else 0
+        mean_recall = sum(recalls) / len(recalls) if recalls else 0
+        mean_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
+
+        return {
+            'Precision': round(mean_precision, 4),
+            'Recall': round(mean_recall, 4),
+            'F1 Score': round(mean_f1, 4),
+            # 'total': len(predictions)
+        }
+
+    def _evaluate_text_similarity(self, predictions, references):
+        """Simple text-similarity fallback (used when ROUGE is unavailable)."""
+        correct = 0
+        total = len(predictions)
+
+        for pred_list, ref_list in zip(predictions, references):
+            pred = pred_list[0].lower().strip()
+            ref = ref_list[0].lower().strip()
+
+            # Simple containment check
+            if pred == ref or pred in ref or ref in pred:
+                correct += 1
+
+        accuracy = correct / total if total > 0 else 0
+        return {
+            'Text_Similarity': round(accuracy, 4),
+            # 'correct': correct,
+            # 'total': total
+        }
+
+    def _evaluate_general(self, predictions, references):
+        """Generic evaluation: case-insensitive exact-match accuracy."""
+        correct = 0
+        total = len(predictions)
+
+        for pred_list, ref_list in zip(predictions, references):
+            pred = pred_list[0].lower().strip()
+            ref = ref_list[0].lower().strip()
+
+            if pred == ref:
+                correct += 1
+
+        accuracy = correct / total if total > 0 else 0
+        return {
+            'Accuracy': round(accuracy, 4),
+            # 'correct': correct,
+            # 'total': total
+        }
diff --git a/opencompass/datasets/SciReasoner/opi/process_ec_numbers.py b/opencompass/datasets/SciReasoner/opi/process_ec_numbers.py
new file mode 100644
index 000000000..e7ad5f485
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/opi/process_ec_numbers.py
@@ -0,0 +1,62 @@
+import json
+import re
+from typing import Any
+
+
+def add_spaces_to_ec_number(text: str) -> str:
+    """Insert spaces into EC numbers: 2.7.10.2 becomes 2 . 7 . 10 . 2."""
+    # Match the EC number format: digits.digits.digits.digits
+    pattern = r'\b(\d+)\.(\d+)\.(\d+)\.(\d+)\b'
+
+    def replace_ec(match):
+        # Adjacent f-string literals concatenate into one replacement string
+        return (f'{match.group(1)} . {match.group(2)} .'
+                f' {match.group(3)} . {match.group(4)}')
+
+    return re.sub(pattern, replace_ec, text)
+
+
+def process_json_value(value: Any) -> Any:
+    """Recursively walk a JSON value, spacing EC numbers in every string."""
+    if isinstance(value, str):
+        return add_spaces_to_ec_number(value)
+    elif isinstance(value, dict):
+        return {k: process_json_value(v) for k, v in value.items()}
+    elif isinstance(value, list):
+        return [process_json_value(item) for item in value]
+    else:
+        return value
+
+
+def process_ec_json_file(input_file: str, output_file: str) -> None:
+    """Process a JSON file, reformatting every EC number with spaces."""
+    try:
+        # Read the JSON file
+        with open(input_file, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        # Process the data
+        processed_data = process_json_value(data)
+
+        # Write the new file
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(processed_data, f, ensure_ascii=False, indent=2)
+
+        print(f'Done! Saved to {output_file}')
+
+    except Exception as e:
+        print(f'Error while processing the file: {e}')
+
+
+if __name__ == '__main__':
+    input_file = \
+        'cot_data/EC_number_train_CLEAN_EC_number_train_train.json_final.json'
+    output_file = \
+        'cot_data/EC_number_train_CLEAN_EC_number_train_train_spaced.json'
+
+    process_ec_json_file(input_file, output_file)
diff --git a/opencompass/datasets/SciReasoner/opi/utils/accuracy4fold_type.py b/opencompass/datasets/SciReasoner/opi/utils/accuracy4fold_type.py
new file mode 100644
index 000000000..e6d8bd0af
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/opi/utils/accuracy4fold_type.py
@@ -0,0 +1,45 @@
+import json
+import os
+
+import tqdm
+
+
+def load_json(file_path):
+    """Load JSON data from a file."""
+    with open(file_path, 'r') as f:
+        return json.load(f)
+
+
+def compute_accuracy4fold_type(eval_file, test_files):
+    """Compute accuracy for predictions against test datasets."""
+    # Load evaluation data
+    eval_data = load_json(eval_file)
+    acc_dict = {}
+    # Iterate over each test file
+    for test_file in test_files:
+        # Load test data
+        test_data = load_json(test_file)
+
+        # Create a set of test sequences
+        test_seq_set = {item['primary'] for item in test_data}
+
+        # Initialize counters
+        correct_predictions = 0
+        total_predictions = 0
+
+        # Evaluate each item in the evaluation data
+        for item in tqdm.tqdm(eval_data):
+            if item['input'] not in test_seq_set:
+                continue
+            predict = item.get('output', item.get('predict', []))
+            label = item['target']
+            if predict == label:
+                correct_predictions += 1
+            total_predictions += 1
+
+        # Calculate and print accuracy
+        accuracy = correct_predictions / total_predictions \
+            if total_predictions > 0 else 0
+        acc_dict[os.path.basename(test_file).split('.')[0][21:]] = round(
+            accuracy, 4)
+    return acc_dict
diff --git a/opencompass/datasets/SciReasoner/opi/utils/metrics4all.py b/opencompass/datasets/SciReasoner/opi/utils/metrics4all.py
new file mode 100644
index 000000000..672c3718e
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/opi/utils/metrics4all.py
@@ -0,0 +1,143 @@
+import argparse
+import json
+import os
+
+import tqdm
+from rouge_score import rouge_scorer
+from sklearn.metrics import (accuracy_score, f1_score, precision_score,
+                             recall_score)
+from sklearn.preprocessing import MultiLabelBinarizer
+
+from .accuracy4fold_type import compute_accuracy4fold_type
+
+
+def calculate_metrics(output, target):
+    # Convert to binary format
+    mlb = MultiLabelBinarizer(classes=sorted(set(output + target)))
+    y_true = mlb.fit_transform([target])
+    y_pred = mlb.transform([output])
+
+    # Calculate metrics
+    accuracy = accuracy_score(y_true, y_pred)
+    precision = precision_score(y_true,
+                                y_pred,
+                                average='micro',
zero_division=0) + recall = recall_score(y_true, y_pred, average='micro', zero_division=0) + f1 = f1_score(y_true, y_pred, average='micro', zero_division=0) + + return accuracy, precision, recall, f1 + + +def calculate_rouge_l(output, target): + scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + scores = scorer.score(' '.join(target), ' '.join(output)) + return scores['rougeL'].fmeasure + + +def process_json_file(json_file_path): + accuracies = [] + precisions = [] + recalls = [] + f1_scores = [] + rouge_ls = [] + + with open(json_file_path, 'r') as file: + data = json.load(file) + + for entry in tqdm.tqdm(data): + output = entry.get('output', entry.get('predict', [])) + target = entry.get('target', []) + + # Ensure both output and target are lists + if isinstance(output, str): + if any(keyword in json_file_path for keyword in + ['EC_number', 'go_terms', 'keywords', 'gene', 'domain']): + output = output.split('; ') + elif any(keyword in json_file_path + for keyword in ['function', 'subcell_loc', 'ss']): + output = [output] + if isinstance(target, str): + if any(keyword in json_file_path for keyword in + ['EC_number', 'go_terms', 'keywords', 'gene', 'domain']): + target = target.split('; ') + elif any(keyword in json_file_path + for keyword in ['function', 'subcell_loc', 'ss']): + target = [target] + + if 'function' in json_file_path: + rouge_l = calculate_rouge_l(output, target) + rouge_ls.append(rouge_l) + elif 'subcell_loc' in json_file_path: + accuracy, _, _, _ = calculate_metrics(output, target) + accuracies.append(accuracy) + else: + _, precision, recall, f1 = calculate_metrics(output, target) + # accuracies.append(accuracy) + precisions.append(precision) + recalls.append(recall) + f1_scores.append(f1) + + if 'function' in json_file_path: + mean_rouge_l = sum(rouge_ls) / len(rouge_ls) if rouge_ls else 0 + return {'ROUGE-L': round(mean_rouge_l, 4)}, None + elif 'subcell_loc' in json_file_path: + mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0 + return {'Accuracy': round(mean_accuracy, 4)}, None + else: + mean_precision = sum(precisions) / len(precisions) if precisions else 0 + mean_recall = sum(recalls) / len(recalls) if recalls else 0 + mean_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 + return { + 'Precision': round(mean_precision, 4), + 'Recall': round(mean_recall, 4), + 'F1 Score': round(mean_f1, 4) + }, None + + +def main(eval_res_path): + results = {} + + # List all JSON files in the directory + for file_name in sorted(os.listdir(eval_res_path)): + if file_name.endswith('.json') and 'metrics_result' not in file_name: + print(f'Processing {file_name}') + file_path = os.path.join(eval_res_path, file_name) + if 'function' in file_path: + metrics, _ = process_json_file(file_path) + results[file_name] = {'ROUGE-L': metrics['ROUGE-L']} + elif 'subcell' in file_path: + metrics, _ = process_json_file(file_path) + results[file_name] = {'Accuracy': metrics['Accuracy']} + elif 'fold_type' in file_path: + test_files = [ + 'compute_scores/remote_homology_test_fold_holdout.json', + ('compute_scores/' + 'remote_homology_test_superfamily_holdout.json'), + 'compute_scores/remote_homology_test_family_holdout.json' + ] + acc_dict = compute_accuracy4fold_type(file_path, test_files) + results[file_name] = acc_dict + else: + metrics, _ = process_json_file(file_path) + results[file_name] = { + # 'Accuracy': metrics['Accuracy'], + 'Precision': metrics['Precision'], + 'Recall': metrics['Recall'], + 'F1 Score': metrics['F1 Score'] + } + 
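+                # Illustrative shape of one entry (placeholder numbers, not
+                # real results): {'Precision': 0.8123, 'Recall': 0.7645,
+                # 'F1 Score': 0.7877}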
+            print(results[file_name])
+    with open(f'{eval_res_path}/metrics_result.json', 'w') as result_file:
+        json.dump(results, result_file, indent=4)
+
+    print(f'Results saved to: {eval_res_path}/metrics_result.json')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--indir',
+                        required=True,
+                        help='Path to the result file dir')
+    args = parser.parse_args()
+
+    main(args.indir)
diff --git a/opencompass/datasets/SciReasoner/uncond_RNA.py b/opencompass/datasets/SciReasoner/uncond_RNA.py
new file mode 100644
index 000000000..8c798ba1c
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/uncond_RNA.py
@@ -0,0 +1,128 @@
+import os
+import re
+import subprocess
+from tempfile import TemporaryDirectory
+from typing import Union
+
+from datasets import Dataset
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+
+
+@LOAD_DATASET.register_module()
+class Uncond_RNA_Dataset(BaseDataset):
+
+    @staticmethod
+    def load(num, prompt):
+        dataset = [{'input': prompt, 'output': ''} for _ in range(num)]
+        return Dataset.from_list(dataset)
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def RNA_postprocessor(text: Union[str, None]) -> str:
+    if not text:
+        return ''
+
+    text = text.replace('T', 'U').replace('t', 'u')
+
+    match = re.search(r'<rna>(.*?)</rna>', text, re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+
+    return ''
+
+
+class RNA_Evaluator(BaseEvaluator):
+
+    def score(self, predictions, references):
+        invalid_count = 0
+        overlength_count = 0
+        valid_rnas = []
+        valid_bases = set('AUCG')
+        avg_mfe = None
+        rfam_families = []
+
+        for idx, seq in enumerate(predictions):
+            seq = seq.strip().upper()
+            if not seq or any(base not in valid_bases for base in seq):
+                invalid_count += 1
+            else:
+                valid_rnas.append((f'seq{idx}', seq))
+                if len(seq) > 1024:
+                    overlength_count += 1
+
+        with TemporaryDirectory() as tmpdir:
+            fasta_path = os.path.join(tmpdir, 'valid_sequences.fasta')
+            with open(fasta_path, 'w') as f:
+                for seq_id, seq in valid_rnas:
+                    f.write(f'>{seq_id}\n{seq}\n')
+
+            mfe_file = self.run_rnafold(fasta_path, tmpdir)
+            mfe_values = self.parse_mfe(mfe_file)
+            avg_mfe = sum(mfe_values) / len(mfe_values) if mfe_values else None
+
+            rfam_cm = 'Rfam/Rfam.cm'
+            rfam_clanin = 'Rfam/Rfam.clanin'
+            rfam_tblout = self.run_cmscan(fasta_path, tmpdir, rfam_cm,
+                                          rfam_clanin)
+            rfam_families = self.parse_unique_families(rfam_tblout)
+
+        return {
+            'total_samples': len(predictions),
+            'invalid_prediction_count': invalid_count,
+            'overlength_prediction_count': overlength_count,
+            'valid_sequence_count': len(valid_rnas),
+            'average_mfe': avg_mfe,
+            'retrieved_rfam_family_count': len(rfam_families),
+        }
+
+    def run_rnafold(self, input_fasta, output_dir):
+        output_file = os.path.join(output_dir, 'mfe_results.txt')
+        cmd = (
+            f'cd {output_dir} && RNAfold < '
+            f'{os.path.abspath(input_fasta)} > {os.path.basename(output_file)}'
+        )
+        ret = subprocess.run(cmd, shell=True)
+        if ret.returncode != 0:
+            print(ret)
+            raise RuntimeError('RNAfold execution failed!')
+        return output_file
+
+    def parse_mfe(self, output_file):
+        mfe_values = []
+        with open(output_file) as f:
+            for line in f:
+                match = re.search(r'\s\(([-\d\.]+)\)\s*$', line.strip())
+                if match:
+                    mfe = float(match.group(1))
+                    mfe_values.append(mfe)
+        return mfe_values
+
+    def run_cmscan(self, fasta_file, output_dir, rfam_cm, rfam_clanin):
+        tblout_path = os.path.join(output_dir, 'cmscan_results.tblout')
+        cmscan_cmd = [
+            'cmscan', '--rfam', '--cut_ga', '--nohmmonly', '--tblout',
+            tblout_path, '--fmt', '2', '--clanin', rfam_clanin, rfam_cm,
+            fasta_file
+        ]
+        result = subprocess.run(cmscan_cmd,
+                                stdout=subprocess.PIPE,
+                                stderr=subprocess.PIPE)
+        if result.returncode != 0:
+            raise RuntimeError(f'cmscan failed:\n{result.stderr.decode()}')
+        return tblout_path
+
+    def parse_unique_families(self, tblout_file):
+        families = set()
+        with open(tblout_file, 'r') as f:
+            for line in f:
+                if line.startswith('#'):
+                    continue
+                parts = line.strip().split()
+                if len(parts) > 0:
+                    family_id = parts[0]
+                    families.add(family_id)
+        return families
diff --git a/opencompass/datasets/SciReasoner/uncond_material.py b/opencompass/datasets/SciReasoner/uncond_material.py
new file mode 100644
index 000000000..10e7ae430
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/uncond_material.py
@@ -0,0 +1,80 @@
+import re
+from typing import Union
+
+from datasets import Dataset
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+
+
+@LOAD_DATASET.register_module()
+class Uncond_material_Dataset(BaseDataset):
+
+    @staticmethod
+    def load(num, prompt):
+        dataset = [{'input': prompt, 'output': ''} for _ in range(num)]
+        return Dataset.from_list(dataset)
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def material_postprocessor(text: Union[str, None]) -> str:
+    if not text:
+        return ''
+
+    match = re.search(r'<material>(.*?)</material>', text,
+                      re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+
+    return ''
+
+
+class uncond_material_Evaluator(BaseEvaluator):
+
+    def score(self, predictions):
+        total = len(predictions)
+        format_valid = 0
+        smact_valid = 0
+        from collections import Counter
+
+        from smact.screening import smact_validity
+        for text in predictions:
+
+            match = re.match(
+                r'([A-Z][a-z]?(?: [A-Z][a-z]?)*?)'
+                r'\s*(?:<|⟨)sg(?:>|⟩)\s*(?:<|⟨)sg(\d+)(?:>|⟩)', text.strip())
+            if not match:
+                continue
+
+            elements_str, sg_num = match.groups()
+            elements = elements_str.split()
+            counter = Counter(elements)
+            formula = ''
+            for el, cnt in sorted(counter.items()):
+                formula += el
+                if cnt > 1:
+                    formula += str(cnt)
+            try:
+                if smact_validity(formula):
+                    smact_valid += 1
+                format_valid += 1
+            except Exception:
+                continue
+
+        smact_validity_ratio_in_format_valid = smact_valid / format_valid \
+            if format_valid else 0
+        smact_validity_ratio_in_all = smact_valid / total if total else 0
+
+        return {
+            'total_samples': total,
+            'format_valid_count': format_valid,
+            'smact_valid_count': smact_valid,
+            'smact_validity_ratio_in_format_valid':
+            smact_validity_ratio_in_format_valid * 100,
+            'smact_validity_ratio_in_all':
+            smact_validity_ratio_in_all * 100,
+        }
diff --git a/opencompass/datasets/SciReasoner/unconditional_molecule_generation/UMG.py b/opencompass/datasets/SciReasoner/unconditional_molecule_generation/UMG.py
new file mode 100644
index 000000000..851a9ab8d
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/unconditional_molecule_generation/UMG.py
@@ -0,0 +1,123 @@
+import re
+
+from datasets import Dataset, DatasetDict
+
+try:
+    from rdkit import Chem
+except Exception:
+    Chem = None
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET
+
+
+@LOAD_DATASET.register_module()
+class UMG_Dataset(BaseDataset):
+
+    @staticmethod
+    def load(max_cut=-1):
+        gen_inst = 'Generate a molecule with <mol></mol>.'
+
+        output_samples = [
+            'CN1C=NC2=C1C(=O)N(C)C(=O)N2C',
+            'c1ccccc1C(=O)O', 'CCO',
+            'CC(=O)Oc1ccccc1C(=O)O', 'CCO'
+        ]
+
+        train_data = [{
+            'input': gen_inst,
+            'output': output,
+        } for output in output_samples]
+
+        len_test_data = 800
+
+        if (max_cut != -1):
+            len_test_data = min(len_test_data, max_cut)
+
+        test_data = [{
+            'input': gen_inst,
+            'output': ''
+        } for i in range(len_test_data)]
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+class UMG_Evaluator(BaseEvaluator):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def is_valid_smiles_rdkit(self, s):
+        """Validate a SMILES string with RDKit."""
+        if not isinstance(s, str) or not s:
+            return False
+        # Treat strings that already contain tag-like markup as not being
+        # a bare SMILES string, so already-tagged SMILES are not
+        # processed twice.
+        if '<' in s or '>' in s:
+            return False
+        # sanitize=False parses SMILES that may be chemically invalid
+        mol = Chem.MolFromSmiles(s, sanitize=False)
+        return mol is not None
+
+    def extract_smiles_simple(self, text: str) -> str | None:
+        # match = re.search(r"⟨mol⟩([A-Za-z0-9()=#+@\\/\.-]+)⟨/mol⟩", text)
+        if '<mol>' not in text:
+            # No explicit tags: wrap whitespace-delimited tokens that parse
+            # as SMILES in <mol> tags so the search below can find them.
+            generic_pat = re.compile(
+                r'(?<!\S)([A-Za-z0-9()=#+@\\/\.-]+)(?!\S)')
+
+            def generic_replace(match):
+                candidate = match.group(1)
+                if len(candidate) >= 4 and self.is_valid_smiles_rdkit(
+                        candidate):
+                    print('candidate', candidate)
+                    return f'<mol> {candidate} </mol>'
+                else:
+                    return candidate
+
+            text = generic_pat.sub(generic_replace, text)
+        match = re.search(r'<mol> ([A-Za-z0-9()=#+@\\/\.-]+) </mol>', text)
+        if match:
+            # Extract the clean SMILES
+            clean_smiles = match.group(1)
+            return clean_smiles
+        else:
+            return text
+
+    def score(self, predictions):
+        if not predictions:
+            return {'validity': 0.0, 'uniqueness': 0.0, 'valid_smiles': []}
+        valid_smiles = []
+        for smiles in predictions:
+            # Guard against None or empty strings reaching RDKit
+            if not smiles or not isinstance(smiles, str):
+                continue
+            smiles = self.extract_smiles_simple(smiles)
+            # Core step: check with RDKit whether the SMILES is valid
+            mol = Chem.MolFromSmiles(smiles.strip())
+            if mol is not None:
+                valid_smiles.append(smiles)
+
+        total_generated = len(predictions)
+        total_valid = len(valid_smiles)
+
+        # Validity = valid SMILES / all generated SMILES
+        validity = float(total_valid) / float(
+            total_generated) if total_generated > 0 else 0.0
+
+        # Uniqueness = unique valid SMILES / all valid SMILES
+        if total_valid > 0:
+            unique_valid_smiles = set(valid_smiles)
+
+            uniqueness = float(len(unique_valid_smiles)) / float(total_valid)
+        else:
+            uniqueness = 0.0
+        print('validity', validity)
+        print('uniqueness', uniqueness)
+        return {'validity': validity, 'uniqueness': uniqueness}
diff --git a/opencompass/datasets/SciReasoner/unconditional_molecule_generation/__init__.py b/opencompass/datasets/SciReasoner/unconditional_molecule_generation/__init__.py
new file mode 100644
index 000000000..a2077c8ad
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/unconditional_molecule_generation/__init__.py
@@ -0,0 +1 @@
+from .UMG import UMG_Dataset, UMG_Evaluator  # noqa: F401, F403
diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/UPG.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/UPG.py
new file mode 100644
index 000000000..35aeef7cd
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/UPG.py
@@ -0,0 +1,195 @@
+import re
+
+from datasets import Dataset, DatasetDict
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+
+
+@LOAD_DATASET.register_module()
+class UPGDataset(BaseDataset):
+
+    @staticmethod
+    def load(tag_bool=True, max_cut=-1):
+        if tag_bool:
+            gen_inst = 'Generate a protein sequence with <protein></protein>.'
+        else:
+            gen_inst = 'Generate a protein sequence.'
+        output_samples = [
+            'MGDVEKGKKIFIMKCSQCHTVEKGGKHKTGPNLHGLFGRKTGQAPGYSYTAANKNK'
+            'GIIWGEDTLMEYLENPKKYIPGTKMIFVGIKKKEERADLIAYLKKATNE',
+            'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKL'
+            'PVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFE'
+            'GDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIED'
+            'GSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAG'
+            'ITLGMDELYK',
+            'MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGY'
+            'NTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRD'
+            'PQGIRAWVAWRNRCQNRDVRQYVQGCGV',
+            'MLEVKERIAQAKAEIPAPVELAPEEIERLLWKLGWRPVAYGSEEKARELDELYGHP'
+            'FAQEHPKEGAAGPVLAAARGGLEEYGAVEWGWGLGEREWAAAGRVAADVVRRLDGEAREGTLPA'
+            'EAEAFPALAAALEHHHHHH',
+            'MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRR'
+            'EAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN',
+        ]
+
+        train_data = [{
+            'input': gen_inst,
+            'output': output,
+        } for output in output_samples]
+
+        len_test_data = 1000
+        # len_test_data = 10
+
+        if (max_cut != -1):
+            len_test_data = min(len_test_data, max_cut)
+
+        test_data = [{
+            'input': gen_inst,
+            'output': ''
+        } for i in range(len_test_data)]
+
+        dataset = DatasetDict({
+            'train': Dataset.from_list(train_data),
+            'test': Dataset.from_list(test_data)
+        })
+        return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module('UPG_postprocess')
+def UPG_postprocess(text):
+    # Check if the input is a string;
+    # if not, return an empty string to improve robustness
+    if not isinstance(text, str):
+        return ''
+
+    # re.findall() searches for all occurrences of the pattern in the string
+    # (.*?) is a non-greedy capture group,
+    # capturing everything between the two tags
+    # re.DOTALL flag makes '.' match any character, including newlines
+    matches = re.findall(r'<protein>(.*?)</protein>', text, re.DOTALL)
+
+    if matches:
+        # If a match is found, take the last one
+        # and strip leading/trailing whitespace
+        s = matches[-1].strip()
+        # Remove ';'
+        s = s.replace(';', '')
+        # Remove spaces
+        s = s.replace(' ', '')
+
+        def clean_prediction(seq: str) -> str:
+            valid = set('ACDEFGHIKLMNPQRSTVWY-'
+                        )  # Valid uppercase amino acid characters
+            return ''.join([aa for aa in seq.upper() if aa in valid])
+
+        s = clean_prediction(s)
+        return s
+    else:
+        # If no match is found, return an empty string
+        return ''
+
+
+class UPG_Evaluator(BaseEvaluator):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _calculate_sequence_identity(self, seq1, seq2):
+        """
+        Calculate sequence identity between two sequences.
+        This is a simplified implementation for sequences of equal length,
+        computed by direct position-wise comparison.
+        More accurate methods may require alignment algorithms
+        (e.g., Smith-Waterman).
+        """
+        if len(seq1) != len(seq2) or not seq1:
+            # For unequal-length or empty sequences, treat identity as 0
+            # or adopt a more complex alignment strategy if needed.
+            # Here we return 0 for simplicity.
+            return 0
+        matches = sum(1 for a, b in zip(seq1, seq2) if a == b)
+        return matches / len(seq1)
+
+    def score(self, predictions, references=None):
+        """
+        Evaluate the generated protein sequences.
+
+        Args:
+            predictions (list[str]): List of model-generated protein sequences.
+            references (list[str], optional):
+                Reference sequences; ignored here.
+
+        Returns:
+            dict: Dictionary containing evaluation metrics.
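+
+        Toy illustration of the diversity metric (editor's note): for
+        predictions ['AAAA', 'AAAA', 'CCCC'], greedy clustering at the
+        99% identity threshold keeps two representatives, so
+        diversity = 2 / 3.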
+ """ + if not predictions: + return { + 'average_length': 0, + 'diversity': 0, + 'average_plddt': 0, + 'info': 'Input predictions list is empty.' + } + + ori_len = len(predictions) + + predictions = [pred for pred in predictions if len(pred) > 0] + predictions = [ + pred for pred in predictions if not (pred.strip() == '') + ] + valid_rate = len(predictions) / ori_len + + # --- 1. Compute Average Length --- + total_length = sum(len(seq) for seq in predictions) + avg_length = total_length / len(predictions) + + # --- 2. Compute Diversity --- + # Use a greedy clustering algorithm with 99% + # sequence identity threshold + clusters_representatives = [] + for seq in predictions: + is_in_existing_cluster = False + for representative in clusters_representatives: + # Note: This uses simplified equal-length identity calculation. + # For sequences of different lengths, + # use sequence alignment tools. + # As a simple strategy, compare only if lengths are close. + if abs( + len(seq) - len(representative) + ) < 20: # Only compare sequences with small length differences + if self._calculate_sequence_identity( + seq, + representative) >= 0.99: # 99% sequence identity + is_in_existing_cluster = True + break + if not is_in_existing_cluster: + clusters_representatives.append(seq) + + num_clusters = len(clusters_representatives) + diversity = num_clusters / len(predictions) + + # --- 3. Compute Average pLDDT --- + # Only compute for sequences shorter than 100 residues + plddt_scores = [] + sequences_for_plddt = [ + seq for seq in predictions if (len(seq) < 100 and len(seq) > 0) + ] + + for s in sequences_for_plddt: + print(s) + + if sequences_for_plddt: + from .omegafold.__main__ import main as plddt_main + plddt_scores = plddt_main(sequences_for_plddt) + avg_plddt = sum(plddt_scores) / len(plddt_scores) + else: + avg_plddt = 0.0 # If no sequences shorter than 100, set to 0 + + return { + 'num_length_less_100': len(sequences_for_plddt), + 'valid_rate': round(valid_rate, 4), + 'average_length': round(avg_length, 2), + 'diversity': round(diversity, 4), + 'average_plddt': round(avg_plddt, 2) + } diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/__init__.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/__init__.py new file mode 100644 index 000000000..bae422486 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/__init__.py @@ -0,0 +1 @@ +from .UPG import UPG_Evaluator, UPG_postprocess, UPGDataset # noqa: F401, F403 diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/main.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/main.py new file mode 100644 index 000000000..bc565a4a1 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/main.py @@ -0,0 +1,71 @@ +from omegafold.__main__ import main + +if __name__ == '__main__': + protein_list = [ + 'MSS', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQ' + 'CQDFTRSAAKTGTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDF' + 'GGYVGSIYQVWGRGTVGIVG', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MSS', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFT' + 'RSAAKTGTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAA' + 'KTGTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 
'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKT' + 'GTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTG' + 'TATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTA' + 'TVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTA' + 'TVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTA' + 'TVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + 'MKTIIL', + 'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT' + 'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG', + ] + print(main(protein_list)) diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__init__.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__init__.py new file mode 100644 index 000000000..1a5a11bb2 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__init__.py @@ -0,0 +1,38 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +""" + +""" +# ============================================================================= +# Imports +# ============================================================================= +from .config import make_config # noqa: F401, F403 +from .model import OmegaFold # noqa: F401, F403 + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +# ============================================================================= +# Classes +# ============================================================================= +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__main__.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__main__.py new file mode 100644 index 000000000..2b0fd17f1 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__main__.py @@ -0,0 +1,104 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" +The main function to run the prediction +""" +# ============================================================================= +# Imports +# ============================================================================= +import gc +import logging +import sys +import time + +import torch + +from . 
import OmegaFold, make_config, pipeline
+
+# =============================================================================
+# Functions
+# =============================================================================
+
+
+@torch.no_grad()
+def main(protein_list):
+    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+    args, state_dict, forward_config = pipeline.get_args()
+    # create the output directory
+    # os.makedirs(args.output_dir, exist_ok=True)
+    # get the model
+    logging.info('Constructing OmegaFold')
+    model = OmegaFold(make_config(args.model))
+    if state_dict is None:
+        logging.warning('Inferencing without loading weight')
+    else:
+        if 'model' in state_dict:
+            state_dict = state_dict.pop('model')
+        model.load_state_dict(state_dict)
+    model.eval()
+    model.to(args.device)
+
+    # logging.info(f"Reading {args.input_file}")
+
+    pLDDT_list = []
+
+    for i, (input_data, save_path) in enumerate(
+            pipeline.list2inputs(
+                protein_list,
+                num_pseudo_msa=args.num_pseudo_msa,
+                output_dir='./',
+                device=args.device,
+                mask_rate=args.pseudo_msa_mask_rate,
+                num_cycle=args.num_cycle,
+            )):
+        logging.info(f'Predicting chain {i + 1}')
+        logging.info(
+            f"{len(input_data[0]['p_msa'][0])} residues in this chain.")
+        ts = time.time()
+        try:
+            output = model(input_data,
+                           predict_with_confidence=True,
+                           fwd_cfg=forward_config)
+        except Exception as e:
+            logging.error(
+                f'Failed to generate {save_path} due to an exception: {e}',
+                exc_info=True)
+            logging.info('Skipping...')
+            # Even though the exception is caught here, no output was
+            # produced for this sequence, so skip it.
+            continue
+        logging.info(f'Finished prediction in {time.time() - ts:.2f} seconds.')
+
+        # logging.info(f"Saving prediction to {save_path}")
+
+        # print(output['confidence'] * 100)
+
+        pLDDT_list.append(output['confidence'].mean().item() * 100)
+
+        # pipeline.save_pdb(
+        #     pos14=output["final_atom_positions"],
+        #     b_factors=output["confidence"] * 100,
+        #     sequence=input_data[0]["p_msa"][0],
+        #     mask=input_data[0]["p_msa_mask"][0],
+        #     save_path=save_path,
+        #     model=0
+        # )
+        # logging.info(f"Saved")
+        del output
+        torch.cuda.empty_cache()
+        gc.collect()
+    logging.info('Done!')
+
+    return pLDDT_list
diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/confidence.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/confidence.py
new file mode 100644
index 000000000..4cf381b69
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/confidence.py
@@ -0,0 +1,152 @@
+# =============================================================================
+# Copyright 2022 HeliXon Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""
+Code for confidence-relevant things
+"""
+
+# =============================================================================
+# Imports
+# =============================================================================
+import argparse
+
+import torch
+from torch import nn
+
+from .
import modules, utils + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= + + +def get_all_confidence( + lddt_per_residue: torch.Tensor, + ca_coordinates: torch.Tensor, + ca_mask: torch.Tensor, + cutoff=15., +) -> float: + """ + Compute an approximate LDDT score for the entire sequence + + lDDT reference: + Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local + superposition-free score for comparing protein structures and models using + distance difference tests. Bioinformatics 29, 2722–2728 (2013). + + Code below adopted from + https://github.com/deepmind/alphafold/blob/1109480e6f38d71b3b265a4a25039e51e2343368/alphafold/model/lddt.py#L19 + + Args: + lddt_per_residue: the lddt score for each of the residues, + of shape [num_res] + ca_coordinates: the c-a coordinates of the residues, + of shape [num_res, 3] + ca_mask: mask of the c-a atoms, + of shape [num_res] + cutoff: The cutoff for each residue pair to be included + + Returns: + The overall confidence for the entire prediction + + """ + + assert ca_coordinates.ndim == 2 + assert lddt_per_residue.ndim == 1 + + # Compute true and predicted distance matrices. + dmat_true = torch.sqrt( + torch.sum((ca_coordinates[:, None] - ca_coordinates[None, :])**2, + dim=-1).add(1e-10)) + + dists_to_score = ( + torch.lt(dmat_true, cutoff) * ca_mask[..., :, None] * + ca_mask[..., None, :] * + (1. - torch.eye(dmat_true.shape[1], device=ca_mask.device)) + # Exclude self-interaction. + ) + + # Normalize over the appropriate axes. + + score = ((lddt_per_residue * + (torch.sum(dists_to_score, dim=(-1, )).add(1e-10))).sum(-1) / + (1e-10 + torch.sum(dists_to_score, dim=(-1, -2)))) + + return score.item() + + +def _compute_confidence(logits: torch.Tensor) -> torch.Tensor: + """ + Computes per-residue pLDDT from logits. 
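+
+    Sketch of the computation (editor's note, assuming num_bins = 50 as in
+    the struct config): the bin centers are 0.01, 0.03, ..., 0.99, and the
+    result is the expectation of softmax(logits) over those centers,
+    i.e. confidence = softmax(logits) @ bin_centers.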
+ + Code below adopted from + https://github.com/deepmind/alphafold/blob/0be2b30b98f0da7aecb973bde04758fae67eb913/alphafold/common/confidence.py#L22 + + Args: + logits: the logits into the softmax, of shape [num_res, num_bins] + + Returns: + predicted_lddt_ca: the predicted CA lddt, of shape [num_res] + + """ + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bin_centers = torch.arange(start=0.5 * bin_width, + end=1.0, + step=bin_width, + device=logits.device) + probs = torch.softmax(logits, dim=-1) + confidence = torch.mv(probs, bin_centers) + return confidence + + +# ============================================================================= +# Classes +# ============================================================================= + + +class ConfidenceHead(modules.OFModule): + """ + This is the same pLDDT head from AF2, which provides a confidence measure + of the model's prediction + + """ + + def __init__(self, cfg: argparse.Namespace): + super().__init__(cfg) + self.network = nn.Sequential( + nn.Linear(cfg.node_dim, cfg.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(cfg.hidden_dim, cfg.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(cfg.hidden_dim, cfg.num_bins), + ) + + def forward(self, node_repr: torch.Tensor) -> torch.Tensor: + node_repr = utils.normalize(node_repr) + logits = self.network(node_repr) + logits = _compute_confidence(logits) + + return logits + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/config.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/config.py new file mode 100644 index 000000000..bb20ced9c --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/config.py @@ -0,0 +1,118 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +""" +Static configuration reside in this file +""" +# ============================================================================= +# Imports +# ============================================================================= +import argparse + + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +def _make_config(input_dict: dict) -> argparse.Namespace: + """Recursively go through dictionary""" + new_dict = {} + for k, v in input_dict.items(): + if type(v) == dict: + new_dict[k] = _make_config(v) + else: + new_dict[k] = v + return argparse.Namespace(**new_dict) + + +def make_config(model_idx: int = 1) -> argparse.Namespace: + if model_idx not in [1, 2]: + raise ValueError('model_idx must be 1 or 2') + cfg = dict(alphabet_size=21, + plm=dict( + alphabet_size=23, + node=1280, + padding_idx=21, + edge=66, + proj_dim=1280 * 2, + attn_dim=256, + num_head=1, + num_relpos=129, + masked_ratio=0.12, + ), + node_dim=256, + edge_dim=128, + relpos_len=32, + prev_pos=dict( + first_break=3.25, + last_break=20.75, + num_bins=16, + ignore_index=0, + ), + rough_dist_bin=dict( + x_min=3.25, + x_max=20.75, + x_bins=16, + ), + dist_bin=dict( + x_bins=64, + x_min=2, + x_max=65, + ), + pos_bin=dict( + x_bins=64, + x_min=-32, + x_max=32, + ), + c=16, + geo_num_blocks=50, + gating=True, + attn_c=32, + attn_n_head=8, + transition_multiplier=4, + activation='ReLU', + opm_dim=32, + geom_count=2, + geom_c=32, + geom_head=4, + struct=dict( + node_dim=384, + edge_dim=128, + num_cycle=8, + num_transition=3, + num_head=12, + num_point_qk=4, + num_point_v=8, + num_scalar_qk=16, + num_scalar_v=16, + num_channel=128, + num_residual_block=2, + hidden_dim=128, + num_bins=50, + )) + cfg['struct_embedder'] = model_idx == 2 + return _make_config(cfg) + + +# ============================================================================= +# Classes +# ============================================================================= +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/decode.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/decode.py new file mode 100644 index 000000000..77f5142ee --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/decode.py @@ -0,0 +1,371 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================
+"""
+For generating the final coordinates of the amino acids of the predicted
+structure
+"""
+# =============================================================================
+# Imports
+# =============================================================================
+import argparse
+import math
+import typing
+
+import torch
+from torch import nn
+
+from . import modules, utils
+
+# =============================================================================
+# Constants
+# =============================================================================
+# =============================================================================
+# Functions
+# =============================================================================
+# =============================================================================
+# Classes
+# =============================================================================
+
+
+class InvariantPointAttention(modules.OFModule):
+    """
+    This is the Invariant Point Attention from Jumper et al. (2021) that
+    performs transformer-like operations on frames
+
+    """
+
+    def __init__(self, cfg: argparse.Namespace) -> None:
+        super(InvariantPointAttention, self).__init__(cfg)
+        node_dim = cfg.node_dim
+        edge_dim = cfg.edge_dim
+        num_head = cfg.num_head
+        num_scalar_qk = cfg.num_scalar_qk
+        num_point_qk = cfg.num_point_qk
+        num_scalar_v = cfg.num_scalar_v
+        num_point_v = cfg.num_point_v
+
+        # For scalar parts
+        self.q_scalar = nn.Linear(node_dim, num_head * num_scalar_qk)
+        self.k_scalar = nn.Linear(node_dim, num_head * num_scalar_qk)
+        self.v_scalar = nn.Linear(node_dim, num_head * num_scalar_v)
+
+        # to reason about the spatial relationships
+        self.q_point = nn.Linear(node_dim, num_head * 3 * num_point_qk)
+        self.k_point = nn.Linear(node_dim, num_head * 3 * num_point_qk)
+        self.v_point = nn.Linear(node_dim, num_head * 3 * num_point_v)
+
+        # trainable per-head weights for the point term (initialized to the
+        # inverse softplus of 1) and a projection for the edge bias
+        self.trainable_point_weights = nn.Parameter(
+            torch.full([cfg.num_head],
+                       fill_value=math.log(math.exp(1.) - 1)), )
+        self.bias_2d = nn.Linear(edge_dim, num_head)
+
+        final_input_dim = edge_dim + num_scalar_v + num_point_v * 4
+        final_input_dim *= num_head
+        # output projection
+        self.output_projection = nn.Linear(final_input_dim, node_dim)
+        self.softplus = torch.nn.Softplus()
+
+        # weighting of each component so the summed logits have roughly
+        # unit variance at initialization
+        num_logit_terms = 3
+        scalar_variance = max(num_scalar_qk, 1) * 1.
+        point_variance = max(num_point_qk, 1) * 9. / 2
+        self.scalar_weight = math.sqrt(1 / (num_logit_terms * scalar_variance))
+        self.point_weight = math.sqrt(1 / (num_logit_terms * point_variance))
+        self.edge_logits_weight = math.sqrt(1 / num_logit_terms)
+
+    def forward(self, node_repr: torch.Tensor, edge_repr: torch.Tensor,
+                frames: utils.AAFrame) -> torch.Tensor:
+        """
+        From Jumper et al.
(2021), Invariant Point Attention + + Args: + node_repr: the node representation, + of shape [num_res, dim_node] + edge_repr: the edge representation, + of shape [num_res, num_res, dim_edge] + frames: the backbone frames of the amino acids, + of shape [num_res] + + Returns: + the node representation update of shape [num_res, dim_node] + + """ + n_head = self.cfg.num_head + + # acquire the scalar part of the attention logits + _q_scalar = self._get_scalar(self.q_scalar, node_repr, n_head) + _k_scalar = self._get_scalar(self.k_scalar, node_repr, n_head) + _v_scalar = self._get_scalar(self.v_scalar, node_repr, n_head) + scalar_logits = torch.einsum('qhc,khc->qkh', _q_scalar, _k_scalar) + scalar_logits *= self.scalar_weight + + # acquire the 2-dimensional bias from the edge representation + edge_logits = self.bias_2d(edge_repr) * self.edge_logits_weight + + # acquire the spatial part of the logits from the frames + _q_point = self._get_point(self.q_point, node_repr, n_head, frames) + _k_point = self._get_point(self.k_point, node_repr, n_head, frames) + _v_point = self._get_point(self.v_point, node_repr, n_head, frames) + dist = ((_q_point[:, None, ...] - _k_point[None, ...])**2) + point_logits = dist.sum([-1, -2]) * self.point_weight + point_logits *= self.softplus(self.trainable_point_weights) / 2 + + # Combine them and take the softmax + logits = scalar_logits + edge_logits - point_logits + logits += utils.mask2bias(frames.mask[None, ..., None]) + attn_w = modules.softmax(logits, dim=-2, in_place=True) + + # get the output + ret_edge = torch.einsum('...qkh,...qkc->...qhc', attn_w, edge_repr) + ret_scalar = torch.einsum('...qkh,...khc->...qhc', attn_w, _v_scalar) + ret_point = torch.einsum('...qkh,...khpc->...qhpc', attn_w, _v_point) + ret_point = frames.position_in_frame(ret_point) + + feat = torch.cat([ + ret_scalar.flatten(start_dim=-2), + ret_point.flatten(start_dim=-3), + utils.get_norm(ret_point).flatten(start_dim=-2), + ret_edge.flatten(start_dim=-2), + ], + dim=-1) + + return self.output_projection(feat) + + @staticmethod + def _get_scalar(linear: nn.Linear, inputs: torch.Tensor, + num_head: int) -> torch.Tensor: + """ + Pass the input through linear and then perform reshaping for the + multi-headed attention + + Args: + linear: the linear module to pass the input into + inputs: the input tensor to the linear module + num_head: the number of heads + + Returns: + key, query, or value for the multi-headed attention, + [num_res, num_head, dim] + + """ + return linear(inputs).unflatten(dim=-1, sizes=[num_head, -1]) + + @staticmethod + def _get_point(linear: nn.Linear, inputs: torch.Tensor, n_head: int, + transformation: utils.AAFrame) -> torch.Tensor: + """ + Pass the input through the linear and perform reshaping for the + multi-headed attention, then transform the points by the transformation + + Args: + linear: the linear module to compute the local points + inputs: the inputs into the linear module, of shape + n_head: the number of head + transformation: the transformation to make local global + + Returns: + points in global frame, [num_res, n_head, -1, 3] + + """ + local_points = linear(inputs).unflatten(dim=-1, sizes=[n_head, -1, 3]) + global_points = transformation.transform(local_points) + return global_points + + +class TorsionAngleHead(modules.OFModule): + """ + Predict the torsion angles of each of the amino acids from + node representation following Jumper et al. 
 (2021)
+    """
+
+    def __init__(self, cfg: argparse.Namespace):
+        super(TorsionAngleHead, self).__init__(cfg)
+
+        self.input_projection = nn.ModuleList(
+            [nn.Linear(cfg.node_dim, cfg.num_channel) for _ in range(2)])
+
+        self.resblock1 = nn.ModuleList([
+            nn.Linear(cfg.num_channel, cfg.num_channel)
+            for _ in range(cfg.num_residual_block)
+        ])
+        self.resblock2 = nn.ModuleList([
+            nn.Linear(cfg.num_channel, cfg.num_channel)
+            for _ in range(cfg.num_residual_block)
+        ])
+
+        # 7 torsion angles per residue, a (sin, cos) pair for each
+        self.unnormalized_angles = nn.Linear(cfg.num_channel, 14)
+
+    def forward(
+            self, representations_list: typing.Sequence[torch.Tensor]
+    ) -> torch.Tensor:
+        """
+        Predict side chains using multi-rigid representations.
+
+        Args:
+            representations_list: A list of activations to
+                predict side chains from.
+        Returns:
+            The normalized sin-cos representation of the torsion angles
+        """
+        act = 0.
+        for (x, layer) in zip(representations_list, self.input_projection):
+            act = layer(torch.relu(x)) + act
+
+        for layer1, layer2 in zip(self.resblock1, self.resblock2):
+            old_act = act
+            act = layer1(torch.relu(act))
+            act = layer2(torch.relu(act))
+            act = old_act + act
+
+        sin_cos_raw = self.unnormalized_angles(torch.relu(act))
+
+        sin_cos_raw = sin_cos_raw.unflatten(dim=-1, sizes=[7, 2])
+        sin_cos_normalized = utils.robust_normalize(sin_cos_raw)
+
+        return sin_cos_normalized
+
+
+class StructureCycle(modules.OFModule):
+    """
+    Each of the cycles from
+    Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
+
+    """
+
+    def __init__(self, cfg: argparse.Namespace) -> None:
+        super(StructureCycle, self).__init__(cfg)
+        self.ipa = InvariantPointAttention(cfg)
+        self.input_norm = nn.LayerNorm(cfg.node_dim)
+        self.transition = nn.ModuleList([
+            nn.Linear(cfg.node_dim, cfg.node_dim)
+            for _ in range(cfg.num_transition)
+        ])
+        self.update_norm = nn.LayerNorm(cfg.node_dim)
+
+        self.affine_update = nn.Linear(cfg.node_dim, 6)
+
+    def forward(
+            self, node_repr: torch.Tensor, edge_repr: torch.Tensor,
+            backbone_frames: utils.AAFrame
+    ) -> typing.Tuple[torch.Tensor, utils.AAFrame]:
+        """
+        Perform one backbone update and node representation update
+
+        Args:
+            node_repr: the node representation,
+                of shape [num_res, dim_node]
+            edge_repr: the edge representation,
+                of shape [num_res, num_res, dim_edge]
+            backbone_frames: the backbone frames of the amino acids,
+                of shape [num_res]
+
+        Returns:
+            the updated node representation and backbone frames
+
+        """
+        node_repr += self.ipa(node_repr, edge_repr, backbone_frames)
+        node_repr = self.input_norm(node_repr)
+        # Transition
+        input_repr = node_repr
+        for layer in self.transition:
+            node_repr = layer(node_repr)
+            if layer is not self.transition[-1]:
+                node_repr = torch.relu(node_repr)
+
+        node_repr += input_repr  # Shortcut residual connection
+        node_repr = self.update_norm(node_repr)
+        backbone_update = self.affine_update(node_repr)
+        frame_update = utils.AAFrame.from_tensor(backbone_update, unit='nano')
+        backbone_frames = backbone_frames * frame_update
+
+        return node_repr, backbone_frames
+
+
+class StructureModule(modules.OFModule):
+    """Jumper et al. (2021) Suppl. Alg.
20 'StructureModule'""" + + def __init__(self, cfg: argparse.Namespace): + super(StructureModule, self).__init__(cfg) + self.node_norm = nn.LayerNorm(cfg.node_dim) + self.edge_norm = nn.LayerNorm(cfg.edge_dim) + self.init_proj = nn.Linear(cfg.node_dim, cfg.node_dim) + + self.cycles = nn.ModuleList( + [StructureCycle(cfg) for _ in range(cfg.num_cycle)]) + + self.torsion_angle_pred = TorsionAngleHead(cfg) + + def forward( + self, node_repr: torch.Tensor, edge_repr: torch.Tensor, + fasta: torch.Tensor, mask: torch.Tensor + ) -> typing.Tuple[torch.Tensor, typing.Dict[str, typing.Union[ + utils.AAFrame, torch.Tensor]]]: + """ + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + + Args: + node_repr: node representation tensor of shape [num_res, dim_node] + edge_repr: edge representation tensor of shape [num_res, dim_edge] + fasta: the tokenized sequence of the input protein sequence + mask + + Returns: + node_repr: The current node representation tensor for confidence + of shape [num_res, dim_node] + dictionary containing: + final_atom_positions: the final atom14 positions, + of shape [num_res, 14, 3] + final_atom_mask: the final atom14 mask, + of shape [num_res, 14] + + """ + node_repr = self.node_norm(node_repr) + edge_repr = self.edge_norm(edge_repr) + + init_node_repr = node_repr + node_repr = self.init_proj(node_repr) + # Initialize the initial frames with Black-hole Jumper et al. (2021) + backbone_frames = utils.AAFrame.default_init(*node_repr.shape[0:1], + unit='nano', + device=self.device, + mask=mask.bool()) + + for layer in self.cycles: + node_repr, backbone_frames = layer(node_repr, edge_repr, + backbone_frames) + + torsion_angles_sin_cos = self.torsion_angle_pred( + representations_list=[node_repr, init_node_repr], ) + + torsion_angles_mask = torch.ones_like(torsion_angles_sin_cos[..., 0], + dtype=torch.bool) + backbone_frames = backbone_frames.to_angstrom(in_place=False) + frames8 = backbone_frames.expand_w_torsion( + torsion_angles=torsion_angles_sin_cos, + torsion_angles_mask=torsion_angles_mask, + fasta=fasta) + pos14, mask14 = frames8.expanded_to_pos(fasta) + return node_repr, { + 'final_frames': frames8, + 'final_atom_positions': pos14, + 'final_atom_mask': mask14 + } + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/embedders.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/embedders.py new file mode 100644 index 000000000..a9a601ac1 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/embedders.py @@ -0,0 +1,384 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
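+# A minimal, self-contained sketch (for reference only) of the half-split
+# rotary position embedding (RoPE) that `_apply_embed` below implements,
+# assuming a [seq_len, dim] input and matching sin/cos tables:
+#
+#     import torch
+#
+#     def rope_sketch(x, sin, cos):
+#         # rotate the two channel halves jointly by a position-dependent
+#         # angle; equivalent to the split/cat in `_apply_embed`
+#         x1, x2 = x.chunk(2, dim=-1)
+#         return torch.cat([x1 * cos - x2 * sin,
+#                           x2 * cos + x1 * sin], dim=-1)
+#
+# Since each position is rotated by an angle proportional to its index, the
+# dot product of two rotated vectors depends only on their relative offset,
+# which is what makes RoPE a relative positional encoding.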
+# ============================================================================= +""" + +""" +# ============================================================================= +# Imports +# ============================================================================= +import argparse +import typing + +import torch +from torch import nn + +from . import modules, utils +from .utils import residue_constants as rc + + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +def _get_pos(shape: torch.Size, device: torch.device, dtype: torch.dtype, + seq_dim: typing.Tuple[int, ...]) -> torch.Tensor: + """Get the position of the tokens given + + Args: + shape: the shape of the tensor to be applied with RoPE + device: the device on which the tensor reside + dtype: the datatype of the tensor + seq_dim: dimensions of the tensor that reference the sequence length + + Returns: + The position tensor of the shape from ~shape indexed by seq_dim + + """ + spatial_shape = [shape[i] for i in seq_dim] + total_len = 1 + for i in spatial_shape: + total_len *= i + position = torch.arange(total_len, dtype=dtype, device=device) + position = position.reshape(*spatial_shape) + + return position + + +def _apply_embed(inputs: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, + seq_dim: typing.Tuple[int, ...]) -> torch.Tensor: + """Applies RoPE to ~inputs + + Args: + inputs: the tensor to which RoPE is applied, the dimensions indexed by + ~seq_dim indicates the spatial dimensions + sin: the sine tensor that constitutes parts of the RoPE, + of spatial shape + vector dimension + cos: the cosine tensor that constitutes parts of the RoPE, + of spatial shape + vector dimension + seq_dim: the dimensions indicating the spatial dimensions, + must be consecutive + + Returns: + tensor with RoPE applied. 
+ + """ + gaps = [(seq_dim[i + 1] - seq_dim[i]) == 1 + for i in range(len(seq_dim) - 1)] + if len(gaps) > 0: + if not all(gaps): + raise ValueError(f'seq_dim must be consecutive, but got {seq_dim}') + + # Align dimensions of sine and cosine + seq_dim = sorted(seq_dim) + end = seq_dim[-1] + for _ in range(seq_dim[0]): + sin = sin.unsqueeze(0) + cos = cos.unsqueeze(0) + end += 1 + + for _ in range(end, inputs.ndim - 1): + sin = sin.unsqueeze(_) + cos = cos.unsqueeze(_) + + # Apply RoPE + x1, x2 = torch.split(inputs, inputs.shape[-1] // 2, dim=-1) + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + +# ============================================================================= +# Classes +# ============================================================================= +class EdgeEmbedder(modules.OFModule): + """ + Embed the input into node and edge representations + + """ + + def __init__(self, cfg: argparse.Namespace) -> None: + super(EdgeEmbedder, self).__init__(cfg) + + self.proj_i = nn.Embedding(cfg.alphabet_size, cfg.edge_dim) + self.proj_j = nn.Embedding(cfg.alphabet_size, cfg.edge_dim) + self.relpos = RelPosEmbedder(cfg.relpos_len * 2 + 1, cfg.edge_dim) + + def forward(self, fasta_sequence: torch.Tensor, + out: torch.Tensor) -> torch.Tensor: + out += self.proj_i(fasta_sequence).unsqueeze(-2) + out += self.proj_j(fasta_sequence).unsqueeze(-3) + out += self.relpos(fasta_sequence.size(-1)) + + return out + + +class RoPE(nn.Module): + """The RoPE module + + Attributes: + input_dim: the dimension of the input vectors. + + """ + + def __init__(self, input_dim: int) -> None: + super(RoPE, self).__init__() + if input_dim % 2 != 0: + raise ValueError( + f'Input dimension for RoPE must be a multiple of 2,' + f' but got {input_dim}') + self.input_dim = input_dim + self.half_size = input_dim // 2 + freq_seq = torch.arange(self.half_size, dtype=torch.float32) + freq_seq = -freq_seq.div(float(self.half_size)) + + self.register_buffer('inv_freq', + torch.pow(10000., freq_seq), + persistent=False) + + def forward(self, tensor: torch.Tensor, + seq_dim: typing.Union[int, tuple]) -> torch.Tensor: + """ + + Args: + tensor: the tensor to apply rope onto + seq_dim: the dimension that represents the sequence dimension + + Returns: + + """ + if isinstance(seq_dim, int): + seq_dim = [ + seq_dim, + ] + sin, cos = self._compute_sin_cos(tensor, seq_dim) + + return _apply_embed(tensor, sin, cos, seq_dim) + + def _compute_sin_cos( + self, tensor: torch.Tensor, seq_dim: typing.Tuple[int] + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """Compute sine and cosine tensors + + Args: + tensor: the tensors to apply RoPE to + seq_dim: the dimension indices of the spatial dimensions + + Returns: + A tuple of tensors where the first one is the sine tensor + and the second one is the cosine tensor + + """ + position = _get_pos(tensor.shape, tensor.device, tensor.dtype, seq_dim) + sinusoid = torch.einsum('..., d->...d', position, self.inv_freq) + sin, cos = torch.sin(sinusoid), torch.cos(sinusoid) + return sin, cos + + +class RelPosEmbedder(nn.Embedding): + """ + Compute the relative positional embedding, + this is the same algorithm in + Jumper et al. (2021) Suppl. Alg. 4 "relpos" + """ + + def forward(self, num_res: int) -> torch.Tensor: + """ + + Args: + num_res: number of residues in input sequence. 
+ + Returns: + + """ + idx = torch.arange(num_res, device=next(self.parameters()).device) + one_side = self.num_embeddings // 2 + idx = (idx[None, :] - idx[:, None]).clamp(-one_side, one_side) + idx = idx + one_side + return super(RelPosEmbedder, self).forward(idx) # [num_res, dim] + + +class StructEmbedder(modules.OFModule): + """ + Encoder for pair wise atom distance without distance clamp + but a sublinear-function with ord encoder. + """ + + def __init__(self, cfg: argparse.Namespace): + super(StructEmbedder, self).__init__(cfg) + self.rough_dist_bin = modules.Val2ContBins(cfg.rough_dist_bin) + self.dist_bin = modules.Val2ContBins(cfg.dist_bin) + self.pos_bin = modules.Val2ContBins(cfg.pos_bin) + + self.aa_embedding = nn.Embedding(21 * 21, embedding_dim=cfg.c) + + frame_num = 8 + atom_num = 14 + + self.dist_bin_embedding = nn.Linear(cfg.dist_bin.x_bins, cfg.c) + self.rough_dist_bin_embedding = nn.Linear(cfg.rough_dist_bin.x_bins, + cfg.c) + + self.dist_bin_linear = nn.Linear(atom_num * atom_num * cfg.c, cfg.c) + self.rough_dist_bin_linear = nn.Linear(atom_num * atom_num * cfg.c, + cfg.c) + + self.pos_bin_embedding = nn.Linear(cfg.pos_bin.x_bins, cfg.c) + self.pos_linear = nn.Linear(frame_num * atom_num * 3 * cfg.c, cfg.c) + + self.linear_z_weights = nn.Parameter( + torch.zeros([cfg.c, cfg.c, cfg.edge_dim])) + self.linear_z_bias = nn.Parameter(torch.zeros([cfg.edge_dim])) + + def forward( + self, + fasta1: torch.Tensor, + fasta2: torch.Tensor, + pos14_a: torch.Tensor, + mask14_a: torch.Tensor, + pos14_b: torch.Tensor, + mask14_b: torch.Tensor, + frame8: utils.AAFrame, + ): + pairwise_fasta = fasta1.unsqueeze(-1) * 21 + fasta2.unsqueeze(-2) + d = torch.norm(pos14_b[None, :, None] - pos14_a[:, None, :, None], + p=2, + dim=-1, + keepdim=False) + d_mask = mask14_b[None, :, None] * mask14_a[:, None, :, None] + d_mask = d_mask.unsqueeze(-1) + local_mask = torch.mul(mask14_b[None, :, None], frame8.mask[:, None, :, + None]) + local_mask = local_mask.unsqueeze(-1) + + local_vec = frame8.unsqueeze(1).unsqueeze(-1).position_in_frame( + pos14_b[None, :, None, :]) + + return self._sharded_compute(pairwise_fasta, d, local_vec, d_mask, + local_mask) + + def _sharded_compute(self, pairwise_fasta: torch.Tensor, d: torch.Tensor, + local_vec: torch.Tensor, d_mask: torch.Tensor, + local_mask: torch.Tensor) -> torch.Tensor: + pairwise_fasta = self.aa_embedding(pairwise_fasta) + d1 = self.rough_dist_bin(d) + d2 = self.dist_bin(d) + d3 = self.pos_bin(local_vec) + + d1 = self.rough_dist_bin_embedding(d1) + d1 = d1 * d_mask + d1 = self.rough_dist_bin_linear(d1.flatten(start_dim=-3)) + + d2 = self.dist_bin_embedding(d2) + d2 = d2 * d_mask + d2 = self.dist_bin_linear(d2.flatten(start_dim=-3)) + + d3 = self.pos_bin_embedding(d3) + d3 = d3 * (local_mask.unsqueeze(-1)) + d3 = self.pos_linear(d3.flatten(start_dim=-4)) + + final_d = d1 + d2 + d3 # + d4 + O_ = torch.einsum('...sdi,...sdj->...sdij', pairwise_fasta, final_d) + Z = torch.einsum('...sdij,ijh->...sdh', O_, + self.linear_z_weights) + self.linear_z_bias + return Z + + +class PairStructEmbedder(StructEmbedder): + + def forward( + self, + fasta: torch.Tensor, + pos14: torch.Tensor, + pos14_mask: torch.Tensor, + frame8: utils.AAFrame, + ): + return super(PairStructEmbedder, self).forward(fasta1=fasta, + fasta2=fasta, + pos14_a=pos14, + pos14_b=pos14, + mask14_a=pos14_mask, + mask14_b=pos14_mask, + frame8=frame8) + + +class RecycleEmbedder(modules.OFModule): + """ + The recycle embedder from Jumper et al. 
(2021) + + """ + + def __init__(self, cfg: argparse.Namespace): + super(RecycleEmbedder, self).__init__(cfg) + + self.layernorm_node = nn.LayerNorm(cfg.node_dim) + self.layernorm_edge = nn.LayerNorm(cfg.edge_dim) + self.dgram = modules.Val2Bins(cfg.prev_pos) + self.prev_pos_embed = nn.Embedding( + cfg.prev_pos.num_bins, + cfg.edge_dim, + ) + if cfg.struct_embedder: + self.embed_struct = PairStructEmbedder(cfg) + + def forward( + self, + fasta: torch.Tensor, + prev_node: torch.Tensor, + prev_edge: torch.Tensor, + prev_x: torch.Tensor, + node_repr: torch.Tensor, + edge_repr: torch.Tensor, + atom14_mask: torch.Tensor, + prev_frames: utils.AAFrame, + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """Recycle the last run + + Args: + fasta: + prev_node: node representations from the previous cycle + of shape [num_res, node_repr_dim] + prev_edge: edge representations from the previous cycle + of shape [num_res, num_res, edge_repr_dim] + prev_x: pseudo beta coordinates from the previous cycle. + of shape [num_res, 3] + node_repr: the node representation to put stuff in + edge_repr: the edge representation to put stuff in + atom14_mask: the mask for the 14 atoms + prev_frames: the frames from the previous cycle + + Returns: + + """ + atom_mask = rc.restype2atom_mask.to(self.device)[fasta] + prev_beta = utils.create_pseudo_beta(prev_x, atom_mask) + d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3)) + d = self.dgram(d) + node_repr[..., 0, :, :] = (node_repr[..., 0, :, :] + + self.layernorm_node(prev_node)) + edge_repr += self.prev_pos_embed(d) + edge_repr += self.layernorm_edge(prev_edge) + if self.cfg.struct_embedder: + edge_repr += self.embed_struct(fasta, prev_x, atom14_mask, + prev_frames) + + return node_repr, edge_repr + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/geoformer.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/geoformer.py new file mode 100644 index 000000000..2635568da --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/geoformer.py @@ -0,0 +1,174 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" +The code for GeoFormer, the main trunk +""" +# ============================================================================= +# Imports +# ============================================================================= +import argparse +import typing + +import torch +from torch import nn + +from . 
import modules, utils + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +# ============================================================================= +# Classes +# ============================================================================= + + +class GeoFormerBlock(modules.OFModule): + """ + One iteration of GeoFormer + + """ + + def __init__(self, cfg: argparse.Namespace) -> None: + super(GeoFormerBlock, self).__init__(cfg) + self.attention_w_edge_bias = modules.AttentionWEdgeBias( + d_node=cfg.node_dim, + d_edge=cfg.edge_dim, + n_head=cfg.attn_n_head, + attn_gating=cfg.gating, + attn_c=cfg.attn_c) + self.column_attention = modules.Attention(q_dim=cfg.node_dim, + kv_dim=cfg.node_dim, + gating=cfg.gating, + n_head=cfg.attn_n_head, + c=cfg.attn_c, + out_dim=cfg.node_dim, + n_axis=1) + self.node_transition = modules.Transition(d=cfg.node_dim, + n=cfg.transition_multiplier, + activation=cfg.activation) + self.out_product = modules.Node2Edge(in_dim=cfg.node_dim, + out_dim=cfg.edge_dim, + proj_dim=cfg.opm_dim) + self.geometric_attention = nn.ModuleList([ + modules.GeometricAttention(d_edge=cfg.edge_dim, + n_axis=2, + c=cfg.geom_c, + n_head=cfg.geom_head) + for _ in range(cfg.geom_count) + ]) + self.edge_transition = modules.Transition(d=cfg.edge_dim, + n=cfg.transition_multiplier, + activation=cfg.activation) + + def forward( + self, + node_repr: torch.Tensor, + edge_repr: torch.Tensor, + mask: torch.Tensor, + *, + fwd_cfg: typing.Optional[argparse.Namespace] = None + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + + Args: + node_repr: + edge_repr: + mask + fwd_cfg: + + Returns: + + """ + node_repr += self.attention_w_edge_bias(node_repr, + edge_repr, + mask, + fwd_cfg=fwd_cfg) + node_repr = self._column_attention(node_repr, mask, fwd_cfg=fwd_cfg) + node_repr += self.node_transition(node_repr, + subbatch_size=fwd_cfg.subbatch_size) + + edge_repr += self.out_product(node_repr, mask) + for layer in self.geometric_attention: + edge_repr += layer(edge_repr, mask[..., 0, :], fwd_cfg=fwd_cfg) + + edge_repr += self.edge_transition(edge_repr, fwd_cfg.subbatch_size) + + return node_repr, edge_repr + + def _column_attention(self, node_repr, mask, fwd_cfg): + node_repr_col = utils.normalize( + node_repr.transpose(-2, -3).contiguous()) + node_repr_col = self.column_attention(node_repr_col, + node_repr_col, + bias=utils.mask2bias( + mask.T[..., None, None, :]), + fwd_cfg=fwd_cfg) + node_repr += node_repr_col.transpose(-2, -3) + return node_repr + + +class GeoFormer(modules.OFModule): + + def __init__(self, cfg: argparse.Namespace): + super(GeoFormer, self).__init__(cfg) + self.blocks = nn.ModuleList( + [GeoFormerBlock(cfg) for _ in range(cfg.geo_num_blocks)]) + self.node_final_proj = nn.Linear(cfg.node_dim, cfg.struct.node_dim) + + def forward( + self, + node_repr: torch.Tensor, + edge_repr: torch.Tensor, + mask: torch.Tensor, + *, + fwd_cfg: typing.Optional[argparse.Namespace] = None + ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + + Args: + node_repr: the node representation from the + pretrained language model, of shape[num_res, dim] + edge_repr: the edge representation from the + pretrained language model, of shape[num_res, num_res, dim] + mask: the mask indicating the 
 validity of the amino acid,
+                of shape [num_res].
+            fwd_cfg: forward configuration (e.g. the subbatch size)
+
+        Returns:
+            node_repr: the node representation used for recycling
+            edge_repr: the edge representation used for recycling
+            final_node: the node representation used for structure generation
+
+        """
+
+        for block in self.blocks:
+            node_repr, edge_repr = block(node_repr,
+                                         edge_repr,
+                                         mask,
+                                         fwd_cfg=fwd_cfg)
+
+        final_node = self.node_final_proj(node_repr)
+        return node_repr, edge_repr, final_node
+
+
+# =============================================================================
+# Tests
+# =============================================================================
+if __name__ == '__main__':
+    pass
diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/model.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/model.py
new file mode 100644
index 000000000..1b98de44b
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/model.py
@@ -0,0 +1,248 @@
+# =============================================================================
+# Copyright 2022 HeliXon Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""
+The OmegaFold model
+"""
+# =============================================================================
+# Imports
+# =============================================================================
+import argparse
+import typing
+
+import torch
+from torch import nn
+
+from .
 import (confidence, decode, embedders, geoformer, modules, omegaplm,
+               utils)
+from .utils import residue_constants as rc
+
+# =============================================================================
+# Constants
+# =============================================================================
+# =============================================================================
+# Functions
+# =============================================================================
+# =============================================================================
+# Classes
+# =============================================================================
+
+
+class OmegaFoldCycle(modules.OFModule):
+
+    def __init__(self, cfg: argparse.Namespace) -> None:
+        super(OmegaFoldCycle, self).__init__(cfg)
+
+        self.geoformer = geoformer.GeoFormer(cfg)
+        self.structure_module = decode.StructureModule(cfg.struct)
+        self.confidence_head = confidence.ConfidenceHead(cfg.struct)
+
+    def forward(
+        self,
+        fasta: torch.Tensor,
+        mask: torch.Tensor,
+        node_repr: torch.Tensor,
+        edge_repr: torch.Tensor,
+        fwd_cfg: typing.Optional[argparse.Namespace],
+    ) -> typing.Tuple[typing.Dict[str, torch.Tensor], typing.Dict[
+            str, typing.Union[torch.Tensor, utils.AAFrame]]]:
+        """
+        The forward method for one iteration of OmegaFold
+
+        Args:
+            fasta: the tokenized sequence of the protein,
+                of shape [num_res]
+            mask: the mask indicating the validity of each token,
+                of shape [num_res]
+            node_repr: the node representation,
+                of shape [num_res, node_repr_dim]
+            edge_repr: the edge representation,
+                of shape [num_res, num_res, edge_repr_dim]
+            fwd_cfg: forward configuration (e.g. the subbatch size)
+
+        Returns:
+            ret: a dictionary of outputs, including the confidence score
+                of the predicted protein structure
+            prev_dict: the features to be recycled in the next cycle
+
+        """
+
+        prev_node, edge_repr, node_repr = self.geoformer(node_repr=node_repr,
+                                                         edge_repr=edge_repr,
+                                                         mask=mask,
+                                                         fwd_cfg=fwd_cfg)
+
+        node_repr, ret = self.structure_module(
+            node_repr=node_repr[..., 0, :, :],
+            edge_repr=edge_repr,
+            fasta=fasta,
+            mask=mask[..., 0, :],
+        )
+
+        ret['confidence'] = self.confidence_head(node_repr)
+
+        prev_dict = {
+            'prev_node': prev_node[..., 0, :, :],
+            'prev_edge': edge_repr,
+            'prev_x': ret['final_atom_positions'],
+            'prev_frames': ret['final_frames'],
+        }
+        return ret, prev_dict
+
+
+_INPUTS = typing.List[typing.Dict[typing.Union[str, int], typing.Any]]
+
+
+class OmegaFold(modules.OFModule):
+    """
+    The entire OmegaFold model, which comprises a pretrained protein
+    language model, an encoder of the primary sequence, and a structure
+    module for decoding
+
+    """
+
+    def __init__(self, cfg: argparse.Namespace) -> None:
+        super(OmegaFold, self).__init__(cfg)
+        self.omega_plm = omegaplm.OmegaPLM(cfg.plm)
+        self.plm_node_embedder = nn.Linear(cfg.plm.node, cfg.node_dim)
+        self.plm_edge_embedder = nn.Linear(cfg.plm.edge, cfg.edge_dim)
+        self.input_embedder = embedders.EdgeEmbedder(cfg)
+        self.recycle_embedder = embedders.RecycleEmbedder(cfg)
+        self.omega_fold_cycle = OmegaFoldCycle(cfg)
+
+    def forward(
+        self,
+        inputs: _INPUTS,
+        predict_with_confidence: typing.Optional[bool] = True,
+        *,
+        fwd_cfg: typing.Optional[argparse.Namespace] = None
+    ) -> typing.Dict[str, typing.Union[torch.Tensor, float]]:
+        """
+        The forward implementation of OmegaFold
+
+        Args:
+            inputs: a list of input feature dictionaries,
+                one per recycling iteration
+            predict_with_confidence: whether to return the cycle with the
+                highest overall confidence rather than the last cycle
+            fwd_cfg: forward configuration
+
+        Returns:
+            the result dictionary of the selected cycle
+
+        """
+        # Preparation before entering the cycles
+        primary_sequence = inputs[0]['p_msa'][..., 0, :]
+        max_confidence = 0
+        prev_dict = self.create_initial_prev_dict(len(primary_sequence))
+        final_result = None
+
+        # Start
cycling + residx_atom14_mask = rc.restype_atom14_mask.to( + device=primary_sequence.device)[primary_sequence] + for cycle_data in inputs: + p_msa, p_msa_mask = cycle_data['p_msa'], cycle_data['p_msa_mask'] + fasta, mask = p_msa[..., 0, :], p_msa_mask[..., 0, :] + node_repr, edge_repr = self.deep_sequence_embed( + p_msa, p_msa_mask, fwd_cfg) + node_recycle, edge_repr = self.recycle_embedder( + fasta=fasta, + prev_node=prev_dict.pop('prev_node'), + prev_edge=prev_dict.pop('prev_edge'), + prev_x=prev_dict.pop('prev_x'), + node_repr=node_repr, + edge_repr=edge_repr, + atom14_mask=residx_atom14_mask, + prev_frames=prev_dict.pop('prev_frames')) + + result, prev_dict = self.omega_fold_cycle(fasta=fasta, + mask=p_msa_mask, + node_repr=node_repr, + edge_repr=edge_repr, + fwd_cfg=fwd_cfg) + + confidence_overall = confidence.get_all_confidence( + result['confidence'], + result['final_atom_positions'][..., 1, :], mask) + result['confidence_overall'] = confidence_overall + if predict_with_confidence: + if confidence_overall > max_confidence: + max_confidence = confidence_overall + final_result = result + else: + final_result = result + + return final_result + + def deep_sequence_embed( + self, + fasta: torch.Tensor, + mask: torch.Tensor, + fwd_cfg: typing.Optional[argparse.Namespace], + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Run the forward method of the pretrained-language model + + Args: + fasta: the fasta sequence + mask: the mask indicating the validity of the token + + Returns: + + """ + node_repr, edge_repr = self.omega_plm(fasta, mask, fwd_cfg=fwd_cfg) + # return node_plm, edge_plm + node_repr = self.plm_node_embedder( + utils.normalize(node_repr, in_place=True)) + edge_repr = edge_repr.permute(1, 2, 0) + edge_repr = self.plm_edge_embedder( + utils.normalize(edge_repr, in_place=True)) + edge_repr = self.input_embedder(fasta[..., 0, :], out=edge_repr) + + return node_repr, edge_repr + + def create_initial_prev_dict( + self, num_res: int) -> typing.Dict[str, torch.Tensor]: + """ + Generate 'previous' (filling with 0's) features for the model + + Args: + num_res: the number of residues + + Returns: + + """ + return { + 'prev_node': + torch.zeros([num_res, self.cfg.node_dim], + device=self.device, + dtype=torch.float), + 'prev_edge': + torch.zeros([num_res, num_res, self.cfg.edge_dim], + device=self.device, + dtype=torch.float), + 'prev_x': + torch.zeros([num_res, 14, 3], + device=self.device, + dtype=torch.float), + 'prev_frames': + utils.AAFrame.default_init(num_res, + 8, + unit='Angstrom', + device=self.device) + } + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/modules.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/modules.py new file mode 100644 index 000000000..5cbd1a0b8 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/modules.py @@ -0,0 +1,636 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" + +""" +# ============================================================================= +# Imports +# ============================================================================= +import argparse +import numbers +import typing + +import torch +from torch import nn + +from . import utils + + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +def softmax(x: torch.Tensor, + dim: int, + *, + dtype: typing.Optional[torch.dtype] = None, + in_place: bool = False) -> torch.Tensor: + """ + In-place or normal softmax + + Args: + x: the input tensor + dim: the dimension along which to perform the softmax + dtype: the data type + in_place: if to perform inplace + + Returns: + + """ + if in_place: + max_val = torch.max(x, dim=dim, keepdim=True)[0] + torch.sub(x, max_val, out=x) + torch.exp(x, out=x) + summed = torch.sum(x, dim=dim, keepdim=True) + x /= summed + return x + else: + return torch.softmax(input=x, dim=dim, dtype=dtype) + + +def _attention( + query: torch.Tensor, key: torch.Tensor, scale: torch.Tensor, + value: torch.Tensor, bias: torch.Tensor, return_edge: bool, + edge_reduction: str, edge_reduction_dim: int +) -> typing.Tuple[torch.Tensor, typing.Optional[torch.Tensor]]: + """Normal attention + + Args: + query: positive tensor of shape (*_q, dim_qk) + key: positive tensor of shape (*_k, dim_qk) + scale: the scaling of logits + value: tensor of shape (*_k, dim_v) + bias: the bias acting as either mask or relative positional encoding + return_edge: if to return the logits of attention + + Returns: + The aggregated tensor of shape (*_q, dim_v) + + """ + logits = torch.einsum('...id, ...jd -> ...ij', query * scale, key) + logits.add_(bias) + attn = softmax(logits, dim=-1, in_place=not return_edge) + out = torch.einsum('...ij, ...jd -> ...id', attn, value) + if return_edge: + attn = getattr(attn, edge_reduction)(dim=edge_reduction_dim) + return out, attn + else: + return out, None + + +def attention( + query: torch.Tensor, + key: torch.Tensor, + scale: typing.Union[torch.Tensor, float], + value: torch.Tensor, + bias: torch.Tensor, + subbatch_size: typing.Optional[int] = None, + *, + return_edge: bool = False, + edge_reduction: str = 'sum', + edge_reduction_dim: int = 0, +) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor]]: + """Computes attention with q, k , v + + Args: + query: positive tensor of shape (*_q, dim_qk) + key: positive tensor of shape (*_k, dim_qk) + scale: the scaling of logits + value: tensor of shape (*_k, dim_v) + bias: the bias acting as either mask or relative positional encoding + subbatch_size: the subbatch size to split the computation into + return_edge: if to return the logits + edge_reduction: + edge_reduction_dim: + + Returns: + The aggregated tensor of shape (*_q, dim_v) 
+ + """ + q_length, k_length, v_dim = query.shape[-2], key.shape[-2], value.shape[-1] + subbatch_size = subbatch_size or q_length + + batch_shape = list(query.shape[:-2]) + factory_kwargs = nn.factory_kwargs({ + 'device': query.device, + 'dtype': query.dtype + }) + output = torch.empty(*batch_shape, q_length, v_dim, **factory_kwargs) + if return_edge: + batch_shape.pop(edge_reduction_dim + 2) + attns = torch.empty(*batch_shape, q_length, k_length, **factory_kwargs) + else: + attns = None + + for i, q_i in enumerate(query.split(subbatch_size, dim=-2)): + start, end = i * subbatch_size, (i + 1) * subbatch_size, + if bias.shape[-2] != q_length: + b_i = bias + else: + b_i = bias[..., start:end, :] + + res, attn = _attention(q_i, key, scale, value, b_i, return_edge, + edge_reduction, edge_reduction_dim) + output[..., start:end, :] = res + if return_edge: + attns[..., start:end, :] = attn + + return output, attns + + +# ============================================================================= +# Classes +# ============================================================================= + + +class OFModule(nn.Module): + """ + The OmegaFold modules + args: The arguments used for each of the modules + """ + + def __init__(self, cfg: typing.Optional[argparse.Namespace]) -> None: + super(OFModule, self).__init__() + self.cfg = cfg + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + +class Transition(OFModule): + + def __init__(self, d: int, n: int, activation: str) -> None: + super(Transition, self).__init__(None) + fc1 = nn.Linear(d, n * d) + fc2 = nn.Linear(n * d, d) + try: + act = getattr(nn, activation)(inplace=True) + except TypeError: + act = getattr(nn, activation)() + self.network = nn.Sequential(fc1, act, fc2) + + def forward(self, x: torch.Tensor, + subbatch_size: typing.Optional[int]) -> torch.Tensor: + subbatch_size = subbatch_size or x.shape[-2] + + out = torch.empty_like(x) + for i, x_i in enumerate(x.split(subbatch_size, dim=0)): + start, end = i * subbatch_size, (i + 1) * subbatch_size + x_i = utils.normalize(x_i) + out[start:end] = self.network(x_i) + return out + + +class MultiHeadedScaling(OFModule): + """ + Perform an element wise scale shift + + """ + + def __init__( + self, + shape: typing.Union[int, typing.List[int], torch.Size], + num_heads: int, + on_out_ready: typing.Optional[typing.Callable[[torch.Tensor], + torch.Tensor]], + dtype: typing.Optional[torch.dtype] = None, + ) -> None: + """ + + Args: + shape: the shape of the input dimensions + num_heads: the number of dimensions to squeeze to + dtype: the dtype of the parameters at generation + on_out_ready: the function called on exit + """ + super(MultiHeadedScaling, self).__init__(None) + factory_kwargs = nn.factory_kwargs({'dtype': dtype}) + if isinstance(shape, numbers.Integral): + shape = (shape, ) + shape = list(tuple(shape)) + self.unsqueeze_dim = -(len(shape) + 1) + shape.insert(0, num_heads) + self.shape = shape + self.split_dims = [1] * num_heads + self.weight = nn.Parameter(torch.empty(self.shape, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(self.shape, **factory_kwargs)) + self.call_on_out_ready = on_out_ready + + self.reset_parameters() + + def forward(self, x: torch.Tensor) -> typing.List[torch.Tensor]: + """ + Element wise multiplication followed by addition + + Args: + x: the input tensor with the trailing dimensions following + ~self.shape + + Returns: + A output 
 tensor of the same shape
+
+        """
+        x = x.unsqueeze(self.unsqueeze_dim) * self.weight + self.bias
+        positive_index = x.ndim + self.unsqueeze_dim
+        if self.call_on_out_ready is not None:
+            x = self.call_on_out_ready(x)
+
+        x = x.split(self.split_dims, dim=positive_index)
+
+        return [x_i.squeeze(positive_index) for x_i in x]
+
+    def reset_parameters(self):
+        nn.init.normal_(self.weight, std=0.02)
+        nn.init.zeros_(self.bias)
+
+
+class Val2ContBins(OFModule):
+
+    def __init__(
+        self,
+        cfg: argparse.Namespace,
+    ):
+        super(Val2ContBins, self).__init__(cfg)
+
+        x_bin_size = (cfg.x_max - cfg.x_min) / (cfg.x_bins - 2)
+
+        self.register_buffer('x_offset',
+                             torch.linspace(cfg.x_min - x_bin_size / 2,
+                                            cfg.x_max + x_bin_size / 2,
+                                            cfg.x_bins),
+                             persistent=False)
+        self.coeff = -0.5 / ((x_bin_size * 0.2)**2)
+        # the `* 0.2` factor makes the soft binning not too blurred
+
+    def forward(self, dist_x):  # (*)
+        x_offset_shape = [1] * len(dist_x.size()) + [len(self.x_offset)]
+        x = dist_x.unsqueeze(-1) - self.x_offset.view(*x_offset_shape)
+        x_norm = self.coeff * torch.pow(x, 2)
+        x_norm = x_norm - x_norm.max(-1, keepdim=True)[0]
+        logits = torch.softmax(x_norm, dim=-1)
+
+        return logits
+
+
+class Val2Bins(OFModule):
+    """
+    Convert continuous values to bins
+
+    Attributes:
+        breaks: the bin boundaries, a linspace from first_break to last_break
+    """
+
+    def __init__(self, cfg: argparse.Namespace) -> None:
+        super(Val2Bins, self).__init__(cfg)
+        self.register_buffer('breaks',
+                             torch.linspace(cfg.first_break, cfg.last_break,
                                            cfg.num_bins - 1),
+                             persistent=False)
+
+    def forward(self, dist: torch.Tensor) -> torch.Tensor:
+        """
+        Assign each value to a discrete bin index
+
+        Args:
+            dist: distances in the euclidean space.
+
+        Returns:
+            the bin index of each distance, same shape as ~dist
+
+        """
+        dist = dist.unsqueeze(-1)
+        dist_bin = torch.sum(torch.gt(dist, self.breaks),
+                             dim=-1,
+                             dtype=torch.long)
+        return dist_bin
+
+
+class Node2Edge(OFModule):
+    """Communicate between tracks
+
+    Faster than OuterProductMean, mostly due to a better implementation
+    """
+
+    def __init__(self, in_dim: int, proj_dim: int, out_dim: int) -> None:
+        super(Node2Edge, self).__init__(None)
+        self.input_proj = nn.Linear(in_dim, proj_dim * 2)
+        self.proj_dim = proj_dim
+        self.out_weights = nn.Parameter(
+            torch.empty(proj_dim, proj_dim, out_dim))
+        self.out_bias = nn.Parameter(torch.empty(out_dim))
+
+    def forward(self, node_repr: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        node_repr = utils.normalize(node_repr)
+        act = self.input_proj(node_repr)
+        mask = mask[..., None]
+        act = act * mask
+        norm = torch.einsum('...sid, ...sjd->...ijd', mask, mask)
+
+        left, right = act.split(self.proj_dim, dim=-1)
+        # We found this implementation to work significantly faster
+        out = torch.einsum('...sid, def, ...sje-> ...ijf', left,
+                           self.out_weights, right) + self.out_bias
+        out = out / (norm + 1e-3)
+
+        return out
+
+
+class Attention(OFModule):
+    """
+    Widely used attention mechanism
+
+    Attributes:
+        qg_weights (nn.Parameter): weight matrices for queries and gates
+        qg_bias (nn.Parameter): biases for queries and gates
+        kv_weights (nn.Parameter): weight matrices for keys and values
+        kv_bias (nn.Parameter): biases for keys and values
+
+        o_weights (nn.Parameter): the output weight matrix
+        o_bias (nn.Parameter): the output bias
+    """
+
+    def __init__(self, q_dim: int, kv_dim: int, n_head: int, gating: bool,
+                 c: int, out_dim: int, n_axis: int) -> None:
+        super(Attention, self).__init__(None)
+        self.c = c
+        self.n_head = n_head
+        self.gating = gating
+        self.q_dim = q_dim
+        self.n_axis = n_axis
+
+        self.qg_weights = nn.Parameter(
+            torch.empty(q_dim, n_axis, n_head, (gating + 1) * c))
+        self.kv_weights =
nn.Parameter( + torch.empty(kv_dim, n_axis, n_head, 2 * c)) + self.qg_bias = nn.Parameter( + torch.empty(n_axis, n_head, 1, c * (1 + gating))) + self.kv_bias = nn.Parameter(torch.empty(n_axis, n_head, 1, c * 2)) + + self.o_weights = nn.Parameter(torch.empty(n_axis, n_head, c, out_dim)) + self.o_bias = nn.Parameter(torch.empty([out_dim, n_axis])) + + def forward( + self, + q_inputs: torch.Tensor, + kv_inputs: torch.Tensor, + bias: torch.Tensor, + *, + fwd_cfg: typing.Optional[argparse.Namespace] = None + ) -> typing.Union[typing.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Perform the standard multi-headed attention with added gating with some + biases + + Args: + q_inputs: the inputs to generate query vectors, + of shape (*, q_len, q_dim, (n_axis)) + kv_inputs: the inputs to generate key and value vectors, + of shape (*, kv_len, kv_dim, (n_axis)) + bias: the bias for the logits + of shape (*, n_head, q_len, kv_len) + fwd_cfg: if return logits + + Return: + output tensor (*, seq_len, o_dim, (n_axis)) + attention logits (Optional) (q_len, kv_len, num_head) + """ + + # Acquire the q, k, v tensors + to_unsqueeze = (q_inputs.shape[-1] != self.n_axis + and q_inputs.shape[-1] == self.q_dim) + if to_unsqueeze: + q_inputs = q_inputs.unsqueeze(-1) + kv_inputs = kv_inputs.unsqueeze(-1) + if bias is not None: + bias = bias.unsqueeze(-4) + + attn_out = self._get_attn_out(q_inputs, kv_inputs, fwd_cfg, bias) + + output = torch.einsum('...rhqc,rhco->...qor', attn_out, self.o_weights) + output += self.o_bias + + if to_unsqueeze: + output = output.squeeze(-1) + return output + + def _get_attn_out(self, q_inputs, kv_inputs, fwd_cfg, bias): + + qg = torch.einsum('...qar,arhc->...rhqc', q_inputs, self.qg_weights) + qg += self.qg_bias + q_out = qg.split(self.c, dim=-1) + q = q_out[0] + + kv = torch.einsum('...kar,arhc->...rhkc', kv_inputs, self.kv_weights) + kv += self.kv_bias + k, v = kv.split([self.c, self.c], dim=-1) + + # Attention + subbatch_size = (q.shape[-4] + if fwd_cfg is None else fwd_cfg.subbatch_size) + attn_out, _ = attention(query=q, + key=k, + value=v, + subbatch_size=subbatch_size, + bias=bias, + scale=self.c**(-0.5)) + # get the gating + if self.gating: + g = torch.sigmoid(q_out[1]) + attn_out *= g + + return attn_out + + +class AttentionWEdgeBias(OFModule): + + def __init__(self, d_node: int, d_edge: int, n_head: int, + attn_gating: bool, attn_c: int) -> None: + super(AttentionWEdgeBias, self).__init__(None) + self.proj_edge_bias = nn.Linear( + in_features=d_edge, + out_features=n_head # , bias=False + ) + self.attention = Attention(q_dim=d_node, + kv_dim=d_node, + n_head=n_head, + gating=attn_gating, + c=attn_c, + out_dim=d_node, + n_axis=1) + + def forward( + self, + node_repr: torch.Tensor, + edge_repr: torch.Tensor, + mask: torch.Tensor, + *, + fwd_cfg: typing.Optional[argparse.Namespace] = None + ) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]: + """ + + Args: + node_repr: + edge_repr: + mask: + fwd_cfg: + + Returns: + + """ + node_repr = utils.normalize(node_repr) + edge_repr = utils.normalize(edge_repr) + # check dim + edge_bias = self.proj_edge_bias(edge_repr).permute(2, 0, 1) + + edge_bias = edge_bias + utils.mask2bias(mask[..., None, None, :]) + attn_out = self.attention(node_repr, + node_repr, + bias=edge_bias, + fwd_cfg=fwd_cfg) + return attn_out + + +def _get_sharded_stacked(edge_repr: torch.Tensor, subbatch_size: int): + subbatch_size = subbatch_size or edge_repr.shape[-2] + idx = 0 + start, end = 0, subbatch_size + while start < 
edge_repr.shape[-2]: + yield start, end, torch.stack( + [edge_repr[start:end], + edge_repr.transpose(-2, -3)[start:end]], + dim=-1) + idx += 1 + start, end = idx * subbatch_size, (idx + 1) * subbatch_size + + +class GeometricAttention(OFModule): + """We have a lot of stuff here for GRAM reduction + + """ + + def __init__(self, d_edge: int, c: int, n_head: int, n_axis: int) -> None: + super(GeometricAttention, self).__init__(None) + self.d_edge = d_edge + self.n_axis = n_axis + self.n_head = n_head + self.linear_b_weights = nn.Parameter( + torch.empty([d_edge, n_axis, n_head])) + self.linear_b_bias = nn.Parameter(torch.empty([n_axis, n_head, 1, 1])) + + self.act_w = nn.Parameter(torch.empty([d_edge, n_axis, d_edge * 5])) + self.act_b = nn.Parameter(torch.empty([n_axis, d_edge * 5])) + + self.out_proj_w = nn.Parameter(torch.empty([n_axis, d_edge, d_edge])) + self.out_proj_b = nn.Parameter(torch.empty([n_axis, d_edge])) + self.glu = nn.GLU() + + self.attention = Attention(q_dim=d_edge, + kv_dim=d_edge, + n_head=n_head, + c=c, + gating=True, + out_dim=d_edge, + n_axis=n_axis) + + def _get_attended(self, edge_repr: torch.Tensor, mask: torch.Tensor, + fwd_cfg) -> torch.Tensor: + attended = torch.empty(*edge_repr.shape, + self.n_axis, + dtype=edge_repr.dtype, + device=edge_repr.device) + b = torch.zeros(self.n_axis, + self.n_head, + *edge_repr.shape[:2], + dtype=edge_repr.dtype, + device=edge_repr.device) + b += utils.mask2bias(mask) + for s, e, edge_r in _get_sharded_stacked( + edge_repr, subbatch_size=fwd_cfg.subbatch_size): + b[..., s:e, :] = torch.einsum( + '...qkcr,crh->...rhqk', edge_r, + self.linear_b_weights) + self.linear_b_bias + for s, e, edge_r in _get_sharded_stacked( + edge_repr, subbatch_size=fwd_cfg.subbatch_size): + attended[s:e] = self.attention(edge_r, edge_r, b, fwd_cfg=fwd_cfg) + return attended[..., 0] + attended[..., 1].transpose(-2, -3) + + def _get_gated(self, edge_repr: torch.Tensor, mask: torch.Tensor, fwd_cfg): + gated = torch.empty(*edge_repr.shape[:2], + self.n_axis, + self.d_edge, + device=edge_repr.device, + dtype=edge_repr.dtype) + for s_row, e_row, edge_row in _get_sharded_stacked( + edge_repr, subbatch_size=fwd_cfg.subbatch_size): + act_row = self._get_act_row(edge_row, mask[s_row:e_row]) + act_g = torch.sigmoid( + torch.einsum('...dr,drc->...rc', edge_row, self.act_w[ + ..., -self.d_edge:]) + self.act_b[..., -self.d_edge:]) + for s_col, e_col, edge_col, in _get_sharded_stacked( + edge_repr, subbatch_size=fwd_cfg.subbatch_size): + act_col = self._get_act_col(edge_col, mask[s_col:e_col]) + ab = torch.einsum('...ikrd,...jkrd->...ijrd', act_row, act_col) + ab = utils.normalize(ab.contiguous()) + gated[s_row:e_row, + s_col:e_col] = torch.einsum('...rd,rdc->...rc', ab, + self.out_proj_w) + gated[s_row:e_row, s_col:e_col].add_(self.out_proj_b) + gated[s_row:e_row, s_col:e_col] *= act_g[:, s_col:e_col] + + return gated.sum(-2) + + def _get_sliced_weight(self, weight: torch.Tensor, shift=0): + w = weight[..., :-self.d_edge].unflatten(-1, sizes=(4, -1)) + w = w[..., shift::2, :] + w = w.flatten(start_dim=-2) + return w + + def _get_act_row(self, edge_row: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + w = self._get_sliced_weight(self.act_w) + b = self._get_sliced_weight(self.act_b) + act = torch.einsum('...dr,drc->...rc', edge_row, w) + b + act = self.glu(act) * mask[..., None, None, None] + return act + + def _get_act_col(self, edge_row: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + w = self._get_sliced_weight(self.act_w, shift=1) + b = 
self._get_sliced_weight(self.act_b, shift=1) + act = torch.einsum('...dr,drc->...rc', edge_row, w) + b + act = self.glu(act) * mask[..., None, None, None] + return act + + def forward(self, edge_repr: torch.Tensor, mask: torch.Tensor, + fwd_cfg) -> torch.Tensor: + edge_repr = utils.normalize(edge_repr) + out = self._get_attended(edge_repr, mask, fwd_cfg) + out += self._get_gated(edge_repr, mask, fwd_cfg) + + return out + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/omegaplm.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/omegaplm.py new file mode 100644 index 000000000..c4f41b6af --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/omegaplm.py @@ -0,0 +1,233 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" + +""" +# ============================================================================= +# Imports +# ============================================================================= +import argparse +import math +import typing + +import torch +from torch import nn + +from . 
import embedders, modules, utils + + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +def _get_qk_scaling(num_res: torch.Tensor, attn_dim: int) -> torch.Tensor: + """ + https://kexue.fm/archives/8823 + + Args: + num_res: [num_chunks] + attn_dim + + Returns: + + """ + return num_res.clamp(min=4e-5).log() / (math.log(512) * attn_dim**0.5) + + +# ============================================================================= +# Classes +# ============================================================================= +class GatedAttentionUnit(modules.OFModule): + """ + + """ + + def __init__(self, cfg: argparse.Namespace): + super(GatedAttentionUnit, self).__init__(cfg) + self.cfg = cfg + self.gva_proj = nn.Sequential( + nn.Linear(cfg.node, cfg.proj_dim * 2 + cfg.attn_dim), nn.SiLU()) + self.multi_headed_scaling = modules.MultiHeadedScaling( + cfg.attn_dim, + num_heads=2, + on_out_ready=lambda x: self.rope(x, x.ndim - 3)) + self.rope = embedders.RoPE(cfg.attn_dim) + self.relpos = embedders.RelPosEmbedder(cfg.num_relpos, embedding_dim=1) + self.output_proj = nn.Linear(cfg.proj_dim, cfg.node) + + def forward( + self, node: torch.Tensor, scaling: torch.Tensor, bias: torch.Tensor, + fwd_cfg: typing.Optional[argparse.Namespace] + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + The forward method of this class + + Args: + node: the node representation + scaling: logits scaling + bias: + fwd_cfg: + + Returns: + + """ + cfg = self.cfg + # initial projection + gates, values, base = self.gva_proj(node).split( + [cfg.proj_dim, cfg.proj_dim, cfg.attn_dim], dim=-1) + queries, keys = self.multi_headed_scaling(base) + + node, edge = modules.attention( + query=queries, + key=keys, + scale=scaling, + value=values, + bias=bias + self.relpos(base.shape[-2])[..., 0], + subbatch_size=fwd_cfg.subbatch_size, + return_edge=True, + edge_reduction='sum', + edge_reduction_dim=-3, + ) + + # unflatten the values, base will be unflattened in self._forward + node = node * gates + node = self.output_proj(node) + return node, edge + + +class OmegaPLMLayer(modules.OFModule): + """One OmegaPLM Layer + + This layer baked the pre-layernorm configuration into the model + + Attributes: + gau: the underlying GAU layer containing most of the computations + + """ + + def __init__(self, cfg: argparse.Namespace) -> None: + super(OmegaPLMLayer, self).__init__(cfg) + self.gau = GatedAttentionUnit(cfg) + + def forward( + self, node: torch.Tensor, qk_scaling: torch.Tensor, bias: torch.Tensor, + fwd_cfg: typing.Optional[argparse.Namespace] + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """Forward method for pre-layernorm + + One layer of OmegaPLM + + Args: + node: the node representation + qk_scaling: the scaling of logits before attention + bias: the bias for logits before attention + fwd_cfg + + Returns: + node and edge representation + + """ + shortcut, node = node, utils.normalize(node) + node, edge = self.gau(node, qk_scaling, bias, fwd_cfg) + node = node + shortcut + return node, edge + + +class OmegaPLM(modules.OFModule): + """Encoder GAU model + + This is the OmegaPLM model in Wu et al. 2022. 
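 + It is described in "High-resolution de novo structure prediction from + primary sequence" (Wu et al., bioRxiv 2022); its gated attention layers + let OmegaFold run from a single sequence and a sampled pseudo-MSA + instead of a real MSA.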
+ + Attributes: + input_embedding: This is an embedding layer + layers: the trunk of the network containing modified GAU layers + output_norm: an output normalization layer + + """ + + def __init__(self, cfg: argparse.Namespace) -> None: + super(OmegaPLM, self).__init__(cfg) + self.input_embedding = nn.Embedding(cfg.alphabet_size, + cfg.node, + padding_idx=cfg.padding_idx) + self.layers = nn.ModuleList( + [OmegaPLMLayer(cfg) for _ in range(cfg.edge)]) + self.output_norm = nn.LayerNorm(cfg.node) + + def forward( + self, tokens: torch.Tensor, mask: torch.Tensor, + fwd_cfg: typing.Optional[argparse.Namespace] + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """Forward method + + Args: + tokens: A tensor of input tokens, + of shape [*, seq_len] + mask: mask indicating the validity of the tokens, + of shape [*, seq_len] + fwd_cfg + + Returns: + + """ + qk_scaling = _get_qk_scaling(mask.sum(-1), self.cfg.attn_dim) + qk_scaling = qk_scaling[..., None, None] + bias = utils.mask2bias(mask[..., None, :]) + + node = self.input_embedding(tokens) + node *= self._get_finetuning_scale(mask, tokens) + edges = torch.empty(len(self.layers), + mask.shape[-1], + mask.shape[-1], + dtype=node.dtype, + device=node.device) + for i, layer in enumerate(self.layers): + node, edges[i] = layer(node, qk_scaling, bias, fwd_cfg) + node = self.output_norm(node) + + # Taking the average + edges /= (mask.any(-1).sum() + 1e-5) + + return node, edges + + def _get_finetuning_scale(self, mask: torch.Tensor, + tokens: torch.Tensor) -> torch.Tensor: + """Token dropout scaling + + This computes the scaling from Rives et al. 2021 + + Args: + mask: the mask indicating the validity of the input sequence + + Returns: + + """ + un_masked_ratio_train = 1 - self.cfg.masked_ratio + src_lengths = mask.sum(-1) + mask_ratio_observed = tokens.eq(21).sum(-1).float() / src_lengths + mask_ratio_observed = torch.where( + mask_ratio_observed == 1., + torch.full_like(mask_ratio_observed, 0.99), mask_ratio_observed) + return un_masked_ratio_train / (1 - mask_ratio_observed)[:, None, None] + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/pipeline.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/pipeline.py new file mode 100644 index 000000000..898dd408c --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/pipeline.py @@ -0,0 +1,424 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +""" +This file contains the utilities that we use for the entire inference pipeline +""" +# ============================================================================= +# Imports +# ============================================================================= +from __future__ import annotations + +import collections +import logging +import ntpath +import os +import os.path +import pathlib +import types +import typing + +import torch +from Bio import PDB as PDB +from Bio.PDB import StructureBuilder +from huggingface_hub import hf_hub_download +from torch import hub +from torch.backends import cuda, cudnn + +from . import utils +from .utils.protein_utils import residue_constants as rc + +try: + from torch.backends import mps  # Compatibility with earlier versions + + _mps_is_available = mps.is_available +except ImportError: + + def _mps_is_available(): + return False + + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +def _set_precision(allow_tf32: bool) -> None: + """Set matmul precision (mostly to do with TensorFloat-32) + + This allows the user to fall back to full fp32 matmuls + + Args: + allow_tf32: whether to allow TensorFloat-32 matrix multiplications + + Returns: + + """ + if int(torch.__version__.split('.')[1]) < 12: + cuda.matmul.allow_tf32 = allow_tf32 + cudnn.allow_tf32 = allow_tf32 + else: + precision = 'high' if allow_tf32 else 'highest' + torch.set_float32_matmul_precision(precision) + + +def path_leaf(path: str) -> str: + """ + Get the filename from the path + + Args: + path: the absolute or relative path to the file + + Returns: + the filename + + """ + head, tail = ntpath.split(path) + return tail or ntpath.basename(head) + + +def fasta2inputs( + fasta_path: str, + output_dir: typing.Optional[str] = None, + num_pseudo_msa: int = 15, + device: typing.Optional[torch.device] = torch.device('cpu'), + mask_rate: float = 0.12, + num_cycle: int = 10, + deterministic: bool = True +) -> typing.Generator[typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + str], None, None]: + """ + Load a fasta file and convert each chain into pseudo-MSA model inputs + + Args: + fasta_path: the path to the fasta file + output_dir: the path to the output directory + num_pseudo_msa: the number of pseudo-MSA rows to generate + device: the device to move the tensors to + mask_rate: the masking rate for the pseudo MSA + num_cycle: the number of recycling iterations to prepare + deterministic: if to seed the random generator with the sequence length + + Returns: + per-chain input dictionaries together with the output file name + + """ + chain_ids: list[str] = [] + aastr: list[str] = [] + with open(fasta_path, 'r') as file: + lines = file.readlines() + name = False + for line in lines: + if len(line) == 0: + continue + if line.startswith('>') or line.startswith(':'): + name = True + chain_ids.append(line[1:].strip('\n')) + else: + if name: + aastr.append(line.strip('\n').upper()) + name = False + else: + aastr[-1] = aastr[-1] + line.strip('\n').upper() + + combined = sorted(list(zip(chain_ids, aastr)), key=lambda x: len(x[1])) + if output_dir is None: + parent = pathlib.Path(fasta_path).parent + folder_name = path_leaf(fasta_path).split('.')[0] + output_dir = os.path.join(parent, folder_name) + os.makedirs(output_dir, exist_ok=True) + try: + name_max = os.pathconf(output_dir, 'PC_NAME_MAX') - 4 + except AttributeError: + # os.pathconf is UNIX specific. Set to 32 for now. 
+ name_max = 32 + + for i, (ch, fas) in enumerate(combined): + fas = fas.replace('Z', 'E').replace('B', 'D').replace('U', 'C') + aatype = torch.LongTensor( + [rc.restypes_with_x.index(aa) if aa != '-' else 21 for aa in fas]) + mask = torch.ones_like(aatype).float() + assert torch.all(aatype.ge(0)) and torch.all(aatype.le(21)), \ + 'Only take 0-20 amino acids as inputs with unknown amino acid ' \ + 'indexed as 20' + if len(ch) < name_max: + out_fname = ch.replace(os.path.sep, '-') + else: + out_fname = f'{i}th chain' + out_fname = os.path.join(output_dir, out_fname + '.pdb') + + num_res = len(aatype) + data = list() + g = None + if deterministic: + g = torch.Generator() + g.manual_seed(num_res) + for _ in range(num_cycle): + p_msa = aatype[None, :].repeat(num_pseudo_msa, 1) + p_msa_mask = torch.rand([num_pseudo_msa, num_res], + generator=g).gt(mask_rate) + p_msa_mask = torch.cat((mask[None, :], p_msa_mask), dim=0) + p_msa = torch.cat((aatype[None, :], p_msa), dim=0) + p_msa[~p_msa_mask.bool()] = 21 + data.append({'p_msa': p_msa, 'p_msa_mask': p_msa_mask}) + + yield utils.recursive_to(data, device=device), out_fname + + +# A variant of fasta2inputs that takes a list of sequences directly +def list2inputs( + protein_list: typing.List[str], + output_dir: typing.Optional[str] = None, + num_pseudo_msa: int = 15, + device: typing.Optional[torch.device] = torch.device('cpu'), + mask_rate: float = 0.12, + num_cycle: int = 10, + deterministic: bool = True +) -> typing.Generator[typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + str], None, None]: + """ + Convert a list of amino-acid sequences into pseudo-MSA model inputs + + Args: + protein_list: the list of amino-acid sequences + output_dir: the path to the output directory + num_pseudo_msa: the number of pseudo-MSA rows to generate + device: the device to move the tensors to + mask_rate: the masking rate for the pseudo MSA + num_cycle: the number of recycling iterations to prepare + deterministic: if to seed the random generator with the sequence length + + Returns: + per-chain input dictionaries together with the output file name + + """ + chain_ids = [f'chain_{i}' for i in range(len(protein_list))] + aastr = protein_list + + combined = sorted(list(zip(chain_ids, aastr)), key=lambda x: len(x[1])) + name_max = 32 + + for i, (ch, fas) in enumerate(combined): + fas = fas.replace('Z', 'E').replace('B', 'D').replace('U', 'C') + aatype = torch.LongTensor( + [rc.restypes_with_x.index(aa) if aa != '-' else 21 for aa in fas]) + mask = torch.ones_like(aatype).float() + assert torch.all(aatype.ge(0)) and torch.all(aatype.le(21)), \ + 'Only take 0-20 amino acids as inputs with unknown amino acid ' \ + 'indexed as 20' + if len(ch) < name_max: + out_fname = ch.replace(os.path.sep, '-') + else: + out_fname = f'{i}th chain' + out_fname = os.path.join(output_dir, out_fname + '.pdb') + + num_res = len(aatype) + data = list() + g = None + if deterministic: + g = torch.Generator() + g.manual_seed(num_res) + for _ in range(num_cycle): + p_msa = aatype[None, :].repeat(num_pseudo_msa, 1) + p_msa_mask = torch.rand([num_pseudo_msa, num_res], + generator=g).gt(mask_rate) + p_msa_mask = torch.cat((mask[None, :], p_msa_mask), dim=0) + p_msa = torch.cat((aatype[None, :], p_msa), dim=0) + p_msa[~p_msa_mask.bool()] = 21 + data.append({'p_msa': p_msa, 'p_msa_mask': p_msa_mask}) + + yield utils.recursive_to(data, device=device), out_fname
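 + + +# Usage sketch (hypothetical sequence and path): +# for batch, out_fname in list2inputs(['MKTAYIAK'], output_dir='out'): +#     ...  # feed each list of recycling inputs to the model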
 + + +def save_pdb(pos14: torch.Tensor, + b_factors: torch.Tensor, + sequence: torch.Tensor, + mask: torch.Tensor, + save_path: str, + model: int = 0, + init_chain: str = 'A') -> None: + """ + Saves the pos14 coordinates as a pdb file + + Args: + pos14: the atom14 representation of the coordinates + b_factors: the b_factors of the amino acids + sequence: the amino acid of the pos14 + mask: the validity of the atoms + save_path: the path to save the pdb file + model: the model id of the pdb file + init_chain: the chain id of the first chain + + Returns: + the structure saved to ~save_path + + """ + builder = StructureBuilder.StructureBuilder() + builder.init_structure(0) + builder.init_model(model) + builder.init_chain(init_chain) + builder.init_seg(' ') + for i, (aa_idx, p_res, b, + m_res) in enumerate(zip(sequence, pos14, b_factors, mask.bool())): + if not m_res: + continue + aa_idx = aa_idx.item() + p_res = p_res.clone().detach().cpu() + if aa_idx == 21: + continue + try: + three = rc.residx_to_3(aa_idx) + except IndexError: + continue + builder.init_residue(three, ' ', int(i), icode=' ') + for j, (atom_name, ) in enumerate( + zip(rc.restype_name_to_atom14_names[three])): + if len(atom_name) > 0: + builder.init_atom(atom_name, + p_res[j].tolist(), + b.item(), + 1.0, + ' ', + atom_name.join([' ', ' ']), + element=atom_name[0]) + structure = builder.get_structure() + io = PDB.PDBIO() + io.set_structure(structure) + os.makedirs(pathlib.Path(save_path).parent, exist_ok=True) + io.save(save_path) + + +def _load_weights( + weights_url: str, + weights_file: str, +) -> collections.OrderedDict: + """ + Loads the weights from either a url or a local file. If from url, + the checkpoint is first downloaded to the local path. + + Args: + weights_url: a url for the weights + weights_file: a local file + + Returns: + state_dict: the state dict for the model + + """ + + weights_file = os.path.expanduser(weights_file) + use_cache = os.path.exists(weights_file) + if weights_file and weights_url and not use_cache: + logging.info( + f'Downloading weights from {weights_url} to {weights_file}') + os.makedirs(os.path.dirname(weights_file), exist_ok=True) + hub.download_url_to_file(weights_url, weights_file) + else: + logging.info(f'Loading weights from {weights_file}') + + return torch.load(weights_file, map_location='cpu') + + +def _get_device(device) -> str: + """ + Infer the accelerator + + Args: + device: the device type + + Returns: + + """ + if device is None: + if torch.cuda.is_available(): + return 'cuda' + elif _mps_is_available(): + return 'mps' + else: + return 'cpu' + elif device == 'cpu': + return device + elif device.startswith('cuda'): + if torch.cuda.is_available(): + return device + else: + raise ValueError('Device cuda is not available') + elif device == 'mps': + if _mps_is_available(): + return device + else: + raise ValueError('Device mps is not available') + else: + raise ValueError(f'Device type {device} is not available') + + +def get_args() -> typing.Tuple[types.SimpleNamespace, collections.OrderedDict, + types.SimpleNamespace]: + + # Build the args object directly instead of parsing with argparse + args = types.SimpleNamespace() + args.num_cycle = 10 + args.subbatch_size = 448 + args.device = None + # TODO: Modify the path of weights_file + args.weights_file = hf_hub_download('SciReason/OmegaFold-release', + 'release2.pt', + repo_type='dataset') + args.weights = 'https://helixon.s3.amazonaws.com/release1.pt' + args.model = 2 + args.pseudo_msa_mask_rate = 0.12 + args.num_pseudo_msa = 15 + args.allow_tf32 = True + + _set_precision(args.allow_tf32) + + if args.model == 1: + weights_url = 'https://helixon.s3.amazonaws.com/release1.pt' + if args.weights_file is None: + args.weights_file = os.path.expanduser( + '~/.cache/omegafold_ckpt/model.pt') + elif args.model == 2: + weights_url = 'https://helixon.s3.amazonaws.com/release2.pt' + if args.weights_file is None: + args.weights_file = os.path.expanduser( + '~/.cache/omegafold_ckpt/model2.pt') + else: + raise ValueError( + f'Model {args.model} is not available, only 1 or 2 supported.')
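 + + # Note: args.weights_file already points at the checkpoint fetched from + # the Hugging Face Hub above, so the S3 weights_url chosen below only + # serves as a fallback download source when no local file exists.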
 + # Load the checkpoint weights + weights = _load_weights(weights_url, args.weights_file) + weights = weights.pop('model', weights) + + # Build the forward_config namespace + forward_config = types.SimpleNamespace( + subbatch_size=args.subbatch_size, + num_recycle=args.num_cycle, + ) + + # Automatically resolve the device + args.device = _get_device(args.device) + + return args, weights, forward_config + + +# ============================================================================= +# Classes +# ============================================================================= +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git 
a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/__init__.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/__init__.py new file mode 100644 index 000000000..631064c28 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/__init__.py @@ -0,0 +1,38 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" + +""" +from ...utils.protein_utils import residue_constants # noqa: F401, F403 +# ============================================================================= +# Imports +# ============================================================================= +from ...utils.protein_utils.aaframe import AAFrame # noqa: F401, F403 + +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Functions +# ============================================================================= +# ============================================================================= +# Classes +# ============================================================================= +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/aaframe.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/aaframe.py new file mode 100644 index 000000000..0e08c5edc --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/aaframe.py @@ -0,0 +1,936 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" +This script contains the Frame object, that acts as an essential part to +convert to full atom coordinates for amino acids. 
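+A frame bundles a rigid transform (a 3x3 rotation matrix and a translation +vector) together with a mask marking which frames are valid.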
+This is inspired by Jumper et al. (2021), where the authors refer to this +object as rigid group/affine update, and we unify the two notions here. + +Some code is adapted from +https://github.com/deepmind/alphafold/blob/main/alphafold/model/all_atom.py +""" +# ============================================================================= +# Imports +# ============================================================================= +from typing import List, Tuple, Union + +import torch +from torch.nn import functional as F + +from ...utils.protein_utils import functions as f +from ...utils.protein_utils import residue_constants as rc + +# ============================================================================= +# Functions +# ============================================================================= +# ============================================================================= +# Constant +# ============================================================================= +_BACKBONE_ROTATE = torch.tensor([ + [-1, 0., 0.], + [0., 1., 0.], + [0., 0., -1], +]) + + +# ============================================================================= +# Classes +# ============================================================================= +class AAFrame(object): + """ + The transformation object that holds translation and rotation + """ + + def __init__(self, + translation: torch.Tensor = None, + rotation: torch.Tensor = None, + mask: Union[torch.Tensor, torch.Tensor] = None, + safe: bool = True, + unit: str = 'Angstrom', + *, + expanded: bool = False) -> None: + """ + Initialize the transformation + + Args: + translation (): the translation vector of shape (*, 3) + rotation (): the rotation matrix of shape (*, 3, 3) + mask (): the mask tensor indicating the presence of + the frame + safe (): if to use safe initialization; if unsafe, it's faster + expanded (): if this frame is expanded to per-residue frames + """ + super(AAFrame, self).__init__() + self.orig = None + if safe: + self.mask = mask + self.translation = translation + self.rotation = rotation + else: + self._mask = mask + self._translation = translation + self._rotation = rotation + + self.expanded_ = expanded + self._unit = unit + + @property + def unit(self) -> str: + """ + Get the unit of the frame + + Returns: + the current unit of this frame + + """ + return self._unit + + def _assign(self, translation: torch.Tensor, rotation: torch.Tensor, + unit: str, mask: torch.Tensor, in_place: bool, + orig: str) -> 'AAFrame': + """ + Create a new one or in-place assignment + + Args: + translation: the translation (center) of the frame + rotation: the rotation of the frame + unit: the unit in which the frame operates + mask: the mask of the frames indicating which components are valid + in_place: if to perform the operation in-place + orig: the info of the origin of the new frame + + Returns: + A new frame, if not in-place, or the original frame with the + attributes + + """ + if in_place: + self._translation, self._rotation = translation, rotation + self._unit, self._mask = unit, mask + return self + else: + return self._construct_frame(translation, + rotation, + mask, + orig=orig, + safe=True, + unit=unit)
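 + + # Only the translation carries units; the conversion helpers below + # rescale it by a factor of 10 and leave the rotations untouched.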
 + + def to_nanometers(self, in_place: bool = True) -> 'AAFrame': + """ + Convert the frame to nanometers + + Args: + in_place: if to perform the operation in place. + + Returns: + + """ + if self._unit == 'Angstrom': + _translation = self._translation / 10 + else: + _translation = self._translation + _unit = 'nano' + _rotation = self._rotation + _mask = self._mask + return self._assign(translation=_translation, + rotation=_rotation, + unit=_unit, + mask=_mask, + orig=f'To nano from {self}', + in_place=in_place) + + def to_angstrom(self, in_place: bool) -> 'AAFrame': + """ + Convert the frame to Angstrom + + Args: + in_place: if to use in_place operation + + Returns: + + """ + if self._unit == 'nano': + _translation = self._translation * 10 + else: + _translation = self._translation + _unit = 'Angstrom' + _rotation = self._rotation + _mask = self._mask + return self._assign(translation=_translation, + rotation=_rotation, + unit=_unit, + mask=_mask, + orig=f'To Angstrom from {self}', + in_place=in_place) + + @property + def translation(self) -> torch.Tensor: + """ + The translation tensor; masked entries were zeroed on assignment + + Returns: + + """ + return self._translation + + @translation.setter + def translation(self, value: torch.Tensor) -> None: + """ + Assign the translation in the frame with masked values set to zeros. + + Args: + value: the translation value + + """ + m = f.bit_wise_not(self.mask.unsqueeze(-1).expand_as(value)) + self._translation = value.masked_fill(m, 0) + + @property + def rotation(self) -> torch.Tensor: + """ + The rotation matrix + + Returns: + + """ + return self._rotation + + @rotation.setter + def rotation(self, value: torch.Tensor) -> None: + """ + Assign the rotation in the frame with masked values set to identity + matrices. + + Args: + value: the rotational matrices + + """ + mask = f.bit_wise_not(self.mask[..., None, None].expand_as(value)) + value = value.masked_fill(mask, 0.) + value = value.masked_fill( + mask * torch.eye(3, dtype=torch.bool).to(mask.device), 1) + self._rotation = value + + @property + def mask(self) -> torch.Tensor: + """ + The mask indicating which frames are valid + + Returns: + + """ + return self._mask + + @mask.setter + def mask(self, value: torch.Tensor): + self._mask = value.bool() + + @classmethod + def default_init( + cls, + *shape, + unit: str = 'Angstrom', + safe: bool = True, + device: torch.device = torch.device('cpu'), + mask: Union[torch.Tensor, torch.Tensor] = None, + ) -> 'AAFrame': + """ + Partially initialize a batch of frames; for now this only supports + one-dimensional shapes + + Args: + shape (): the shape of the frames, + mask (): the mask, if not provided, will be all true + device (): on which will the frame reside + safe (): if to safe init + unit (): the unit + + Returns: + + """ + if mask is not None: + assert tuple(mask.shape) == shape + translation = torch.zeros(list(shape) + [3], device=device) + rotation = torch.eye(3, dtype=translation.dtype, + device=device) * torch.ones(list(shape) + [1, 1], + device=device) + if mask is None: + mask = torch.ones_like(translation[..., 0], dtype=torch.bool) + + return cls._construct_frame(trans=translation, + rots=rotation, + mask=mask, + orig='partially initialized', + safe=safe, + unit=unit) + + @classmethod + def _neg_dim(cls, dim: int) -> Tuple[int, int, int]: + if dim < 0: + return dim, dim - 1, dim - 2 + else: + return dim, dim, dim + + def unsqueeze(self, dim: int) -> 'AAFrame': + """ + see torch.unsqueeze + + Args: + dim (): + + Returns: + + """ + return self.dim_apply(torch.unsqueeze, dim=dim) + + def sum(self, dim: int, keepdim: bool = False) -> 'AAFrame': + """ + see torch.sum + + Args: + dim (): + keepdim (): + + Returns: + + """ + dim0, dim1, dim2 = self._neg_dim(dim) + m = 
torch.sum(self.mask, dim=dim0, keepdim=keepdim) + t = torch.sum(self.translation, dim=dim1, keepdim=keepdim) + r = torch.sum(self.rotation, dim=dim2, keepdim=keepdim) + return self._construct_frame(t, + r, + m, + f'Created by {torch.sum} at dim {dim}', + safe=False, + unit=self.unit) # from self + + def dim_apply(self, func: callable, dim: int) -> 'AAFrame': + """ + Apply torch functionals to the translation and rotations + + Args: + func (): the functional to apply to + dim (): the dimension to which the function will be applied + + Returns: + + """ + dim0, dim1, dim2 = self._neg_dim(dim) + m = func(self.mask, dim0) + t = func(self.translation, dim1) + r = func(self.rotation, dim2) + u = self.unit + return self._construct_frame(t, + r, + m, + f'Created by {func} at dim {dim}', + safe=False, + unit=u) # from self + + @classmethod + def _construct_frame( + cls, + trans: torch.Tensor, + rots: torch.Tensor, + mask: Union[torch.Tensor, torch.Tensor], + orig: str, + safe: bool, + unit: str, + ) -> 'AAFrame': + """ + Construct a frame + + Args: + trans: the absolute position in the bigger frame + rots: the rotation of the frame + mask: the mask indicating the validity of the frame + orig: the message information about the origin of the frame + unit: the unit for initialize + safe: if use safe init + + Returns: + + """ + # assert t.shape[:-1] == r.shape[:-2] == m.shape + transformation = AAFrame(translation=trans, + rotation=rots, + mask=mask, + safe=safe, + unit=unit) + transformation.orig = orig + + return transformation + + @classmethod + def from_4x4(cls, m: torch.Tensor, mask: torch.Tensor, + unit: str) -> 'AAFrame': + """ + get the frames from 4x4 matrix + + Args: + m (): the transformation in homogeneous coordinates + should be of shape (*, 4, 4) + mask (): the masking tensor + unit (): + + Returns: + A transformation + + """ + + return cls._construct_frame(m[..., 0:3, 3], + m[..., 0:3, 0:3], + mask=mask, + orig='from matrix', + safe=True, + unit=unit) + + def transform(self, pos: torch.Tensor) -> torch.Tensor: + """ + Apply the transformation on the input coordinates + + Args: + pos (): the 3-D coordinates to transforms, + of shape (*, 3) + + Note: + if we are using batched dims, we simply assume that the + dimensions of pos can be split into three parts + 1. the batched_dims + 2. the ones to do the outer-product-like expansion + 3. the 3 xyz coordinate value + + Returns: + transformed coordinates of the same shape as the coordinates, + of shape (N, 3) + + Examples: + >>> frames = AAFrame( + ... translation=torch.zeros(10,3), + ... rotation=torch.eye(3)[None, ...].repeat(10, 1, 1), + ... mask=torch.ones(10, dtype=torch.bool) + ... ) + >>> frames.shape + torch.Size([10]) + >>> frames.transform(torch.randn(10, 3)).shape + torch.Size([10, 3]) # one-to-one + >>> frames.transform(torch.randn(10, 1, 3)).shape + torch.Size([10, 1, 3]) # it is still one-to-one + >>> frames.transform(torch.randn(1, 4, 3)).shape + torch.Size([10, 4, 3]) # this broadcasts to every pair, + # with the first dimension being the + # frames + >>> frames.transform(torch.randn(4, 1, 3)).shape + torch.Size([10, 1, 3]) + >>> frames = AAFrame( + ... translation=torch.zeros(10, 9, 3), + ... rotation=torch.eye(3)[None, ...].repeat(10, 9, 1, 1), + ... mask=torch.ones(10, 9, dtype=torch.bool) + ... 
) + >>> frames.shape + torch.Size([10, 9]) + >>> frames.transform(torch.randn(10, 9, 3)).shape + torch.Size([10, 9, 3]) + >>> frames.transform(torch.randn(10, 1, 3)).shape + torch.Size([10, 9, 3]) # this broadcasts to 9, but does not + # work with shape (1, 9, 3) + >>> frames.transform(torch.randn(1, 1, 3)).shape + torch.Size([10, 9, 3]) # + >>> frames.transform(torch.randn(10, 9, 4, 3)).shape + torch.Size([10, 9, 4, 3]) + >>> frames.transform(torch.randn(10, 1, 9, 4, 3)).shape + torch.Size([10, 9, 9, 4, 3]) # the 1st, 2nd dim are from frames + >>> frames.transform(torch.randn(10, 1, 1, 3)).shape + torch.Size([10, 9, 1, 3]) + """ + batched_dims = len(self.shape) + shape1 = self.shape[:batched_dims] + shape2 = pos.shape[batched_dims:-1] # the ones to cross + self_shape2 = self.shape[batched_dims:] + out = self.view(*shape1, *[1 for _ in range(len(shape2))], + *self_shape2) + return f.batch_matrix_vector(out.rotation, pos) + out.translation + + @classmethod + def from_torsion( + cls, + unit: str, + torsion_angles: torch.Tensor, + mask: Union[torch.Tensor, torch.Tensor], + translation: torch.Tensor = None, + ) -> 'AAFrame': + """ + Create a transformation that rotates around the x-axis + + Args: + unit (): + torsion_angles (): the torsion angle to create the axis with, + should be of shape (*, 2) + mask (): the masking tensor + translation (): optional, if provided will be passed in to the + transformation + + Returns: + A rotation matrix around the x axis + + """ + device = torsion_angles.device + _make_rot_mat = torch.tensor( + [ + [0., 0., 0., 0., 0., -1, 0., 1., 0.], # sin + [0., 0., 0., 0., 1., 0., 0., 0., 1.], # cos + ], + dtype=torsion_angles.dtype, + device=device) + rot_mat = torch.matmul(torsion_angles, _make_rot_mat) + + rot_mat = rot_mat.unflatten(dim=-1, sizes=[3, 3]) + rot_mat[..., 0, 0] = 1 + + if translation is None: + shape = list(torsion_angles.shape) + shape[-1] = 3 + translation = torch.zeros(*shape, device=device) + + return cls._construct_frame(translation, + rot_mat, + mask, + 'from torsion', + safe=True, + unit=unit) + + def __getitem__(self, idx: Union[slice, int, torch.Tensor]) -> 'AAFrame': + """ + Select the frame + + Args: + idx (): the index of the selection + + Returns: + selected transformation + + """ + if isinstance(idx, (slice, int)): + return self._construct_frame(self.translation[..., idx, :], + self.rotation[..., idx, :, :], + self.mask[..., idx], + f'selected from {self} at {idx}', + unit=self.unit, + safe=False) + elif isinstance(idx, torch.Tensor): + return self._construct_frame(self.translation[idx, :], + self.rotation[idx, :, :], + self.mask[idx], + f'selected from {self} by tensor', + unit=self.unit, + safe=False) + else: + raise IndexError(f'Type {type(idx)} not supported for indexing') + + def __setitem__(self, key: Union[int, torch.Tensor, List[int]], + value: Union[torch.Tensor, 'AAFrame']) -> None: + if isinstance(value, AAFrame): + t = value.translation.to(self._translation.dtype) + r = value.rotation.to(self._rotation.dtype) + m = value.mask.to(self._mask.dtype) + else: + t = r = value + m = bool(value) + mask = self.mask.clone() + translation = self.translation.clone() + rotation = self.rotation.clone() + + if isinstance(key, int): + mask[..., key] = m + translation[..., key, :] = t + rotation[..., key, :, :] = r + elif isinstance(key, (torch.Tensor, list)): + # this because it cannot use in-place operations for gradients + mask[key] = m + translation[key, :] = t + rotation[key, :, :] = r + + self.mask = mask + self.translation = translation 
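 + # (these assignments go through the property setters, which re-apply + # the mask: masked translations are zeroed and masked rotations are + # reset to the identity)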
+ self.rotation = rotation + + @property + def device(self) -> torch.device: + """ + + Returns: + + """ + assert (self._mask.device == self._translation.device == + self._rotation.device) + return self._mask.device + + @property + def shape(self) -> torch.Size: + """ + + Returns: the shape of the tensor + + """ + return self.mask.shape + + def __mul__(self, other) -> 'AAFrame': + if isinstance(other, AAFrame): + return self._combine_transformation(other) + else: + return self._tensor_multiplication(other) + + def _tensor_multiplication(self, other: torch.Tensor) -> 'AAFrame': + """ + Multiply everything by the tensor + + Args: + other: + + Returns: + + """ + if torch.logical_or(torch.eq(other, 0), torch.eq(other, 1)).all(): + m = self.mask * other + t = self.translation * other[..., None] + r = self.rotation * other[..., None, None] + else: + t = self.translation * other + m = self.mask + r = self.rotation + + return self._construct_frame(t, + r, + m, + f'Created by multiplication from {self}', + safe=False, + unit=self.unit) + + def _combine_transformation(self, other: 'AAFrame') -> 'AAFrame': + """ + Combine two frames + + Args: + The following two arguments all have the transition of shape + (N, 3) at the first place and rotation matrix of shape (N, 3, 3) at + the second + + other (): frame 1 + + Returns: + the end frame + + """ + # the rotation + if self.shape != other.shape: + t_1 = self.translation[..., None, :].expand_as(other.translation) + r_1 = self.rotation[..., None, :, :].expand_as(other.rotation) + m_1 = self.mask[..., None].expand_as(other.mask).reshape(-1) + t_1, r_1 = t_1.reshape(-1, 3), r_1.reshape(-1, 3, 3) + else: + t_1, r_1, m_1 = self.translation, self.rotation, self.mask + t_1, r_1, m_1 = t_1.view(-1, 3), r_1.view(-1, 3, 3), m_1.view(-1) + + if self.unit == 'Angstrom': + other.to_angstrom(in_place=True) + else: + other.to_nanometers(in_place=True) + t_2, r_2 = other.translation.view(-1, 3), other.rotation.view(-1, 3, 3) + m_2 = other.mask.view(-1) + + r_out = torch.bmm(r_1, r_2) + # the transition + t_out = t_1 + f.batch_matrix_vector(r_1, t_2) + # the torsion_angles_mask + m_out = m_1 * m_2 + + return self._construct_frame(t_out.view(*other.shape, 3), + r_out.view(*other.shape, 3, 3), + m_out.view(*other.shape), + f'Combination of {self} and {other}', + safe=False, + unit=self.unit) + + def __repr__(self) -> str: + return f'Frame {id(self)}' + + def view(self, *args) -> 'AAFrame': + """ + See Tensor.view + + Args: + *args (): + + Returns: + + """ + mask = self.mask + translation = self.translation + rotation = self.rotation + return self._construct_frame(translation.view(*args, 3), + rotation.view(*args, 3, 3), + mask.view(*args), + f'view from {self}', + safe=False, + unit=self.unit) + + @property + def dtype(self): + return self.translation.dtype + + def expand_w_torsion(self, torsion_angles: torch.Tensor, + torsion_angles_mask: torch.Tensor, + fasta: torch.Tensor) -> 'AAFrame': + r""" + Compute the global frame + + Lines 2-10 + Algorithm 24, Page 31 of the AlphaFold 2 supplementary material + + Args: + self (): the transformation from backbone to global + bb_coor (): the transition of the backbone transformation, or + the coordinates of the CA atom, + should be of shape (N, 3) + bb_rot (): the rotation of the backbone transformation, + should be of shape (N, 3, 3) + torsion_angles (): the torsion angles + (\omega, \phi, \psi, \chi_1, \chi_2, \chi_3, \chi_4) + should be of shape (N, 7, 2) + torsion_angles_mask (): the torsion angle masks indicating presence + 
(\omega, \phi, \psi, \chi_1, \chi_2, \chi_3, \chi_4) + should be of shape (N, 7) + fasta (): input sequence where each place is an index indicating + which amino acid is in each position, following ~restypes + + Returns: + Frame + + """ + assert self.unit == 'Angstrom' + if torsion_angles.shape[-2] == 5: + torsion_angles = torch.cat((torch.zeros_like( + torsion_angles[..., 0:2, :]), torsion_angles), + dim=-2) + torsion_angles_mask = torch.cat((torch.zeros_like( + torsion_angles_mask[..., 0:2]), torsion_angles_mask), + dim=-1) + + # append an identity for backbone2backbone + shape = list(torsion_angles.shape) + shape[-2] = 1 + angle = torch.tensor([[0, 1]], dtype=self.dtype, + device=self.device).expand(shape) # (*, 1, 2) + angle_mask = torch.tensor([True], dtype=torch.bool, + device=self.device).expand(shape[:-1]) + torsion_angles = torch.cat((angle, torsion_angles), -2) # (*, 8, 2) + torsion_angles_mask = torch.cat((angle_mask, torsion_angles_mask), -1) + + # prepare the angles + torsion_angles = f.robust_normalize(torsion_angles) + rot_x = AAFrame.from_torsion(torsion_angles=torsion_angles, + mask=torsion_angles_mask, + unit='Angstrom') + + # make extra backbone frames + # This follows the order of ~restypes + m = rc.restype_aa_default_frame.to(self.device)[fasta] + default_frames = AAFrame.from_4x4(m, + torsion_angles_mask, + unit='Angstrom') + all_frames = default_frames * rot_x + # make side chain frames (chain them up along the side chain) + chi2_frame_to_frame = all_frames[5] + chi3_frame_to_frame = all_frames[6] + chi4_frame_to_frame = all_frames[7] + # chains + chi1_frame_to_backb = all_frames[4] + chi2_frame_to_backb = chi1_frame_to_backb * chi2_frame_to_frame + chi3_frame_to_backb = chi2_frame_to_backb * chi3_frame_to_frame + chi4_frame_to_backb = chi3_frame_to_backb * chi4_frame_to_frame + + # all_frames[4] = chi1_f2bb + all_frames[5] = chi2_frame_to_backb + all_frames[6] = chi3_frame_to_backb + all_frames[7] = chi4_frame_to_backb + # get all + # map atom literature positions to the global frame + all_f2global = self * all_frames + all_f2global.expanded_ = True + + return all_f2global + + def rotate(self, rotation: torch.Tensor): + """ + Rotate with a rotation matrix + + Note: + batched rotated not yet supported, + for now just use ~Frame._construct_transformation + + Args: + rotation (): the rotation matrix of shape (d, d) + + Returns: + Rotated frame + + """ + if len(rotation.shape) == 2: + t = self.translation + r = torch.matmul(self.rotation, rotation) + return self._construct_frame(t, + r, + self.mask, + f'Rotated from {self}', + safe=False, + unit=self.unit) + else: + raise NotImplementedError('Not yet implemented') + + def expanded_to_pos( + self, + fasta: torch.Tensor, + full: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the full atom representation + + Args: + fasta: the sequence to compute the atoms + full: if to use safe initialization + + Returns: + the atom14 representation and the mask indicating the presence + of the atoms + + """ + if full: + assert self.expanded_ + num_classes = 8 + frame = self + pos_counts = 14 + else: + num_classes = 1 + frame = self.unsqueeze(-1) + pos_counts = 5 + + assert self._unit == 'Angstrom' + + fasta = fasta.cpu() + residx2group = rc.restype_atom14_to_aa + residx2group = residx2group[..., :pos_counts] + residx2group = residx2group[fasta].to(self.device) + group_mask = F.one_hot(residx2group, num_classes=8) + group_mask = group_mask[..., :num_classes] + group_mask = group_mask * frame.mask[..., None, :] + 
to_mask = frame.unsqueeze(-2) * group_mask + map_atoms_to_global = to_mask.sum(-1) + lit_pos = rc.restype_atom14_aa_positions + lit_pos = lit_pos[..., :pos_counts, :] + lit_pos = lit_pos[fasta].to(self.device) + pred_pos = map_atoms_to_global.transform(lit_pos) + # mask = c.restype_atom14_mask[sequence] # (N, 14) + # mask |= self.mask[..., None] + pred_pos = pred_pos * map_atoms_to_global.mask[..., None] + + return pred_pos, torsion_mask_to_atom14_mask(frame.mask, + group_mask, + fasta=fasta) + + def __len__(self): + return len(self.mask) + + @property + def inverse(self) -> 'AAFrame': + """ + The inverse of the transformation + + Returns: + + """ + r = self.rotation.transpose(-1, -2) + t = f.batch_matrix_vector(r, self.translation) + return self._construct_frame(-t, + r, + self.mask, + f'inversed from {self}', + safe=False, + unit=self.unit) + + def position_in_frame(self, pos: torch.Tensor) -> torch.Tensor: + """ + Get the frame-based position of the given global position + + Args: + pos (): the global position of shape (*, 3) + + Returns: + the result + + """ + return self.inverse.transform(pos) + + @classmethod + def from_tensor(cls, tensor, unit: str) -> 'AAFrame': + """ + Args: + tensor: (*, 7) + unit: + """ + q_dim = 4 if tensor.shape[-1] == 7 else 3 + quaternion, tx, ty, tz = torch.split(tensor, [q_dim, 1, 1, 1], dim=-1) + rotation = f.quaternion_to_matrix(quaternion) + translation = torch.stack([tx[..., 0], ty[..., 0], tz[..., 0]], dim=-1) + + return cls._construct_frame(trans=translation, + rots=rotation, + mask=torch.ones_like(translation[..., 0]), + orig='from tensor', + safe=True, + unit=unit) + + +def torsion_mask_to_atom14_mask(torsion_mask: torch.Tensor, + group_mask: torch.Tensor, + fasta: torch.Tensor) -> torch.Tensor: + """ + expand the mask of torsion angles into atom14 masks + + Args: + torsion_mask (): the mask for torsion angles, of shape (*, 8) + group_mask (): the group mask to add on, of shape (*, 14, 8) + fasta (): the sequence for this operation + + Returns: + Expanded mask of shape (*, 14) + + """ + atom14_exist_mask = group_mask[..., 1:].sum(-1) + atom14_exist_mask[..., 4] = fasta != 7 + atom14_exist_mask[..., 0:3] = torsion_mask[..., 0:1] + return atom14_exist_mask.bool() + + +# ============================================================================= +# Functions +# ============================================================================= + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/functions.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/functions.py new file mode 100644 index 000000000..3a475eac9 --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/functions.py @@ -0,0 +1,149 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" +This script contains some functions that may be handy somewhere +""" +# ============================================================================= +# Constants +# ============================================================================= +# ============================================================================= +# Imports +# ============================================================================= +import typing + +import torch + + +# ============================================================================= +# Functions +# ============================================================================= +def get_norm(x: torch.Tensor) -> torch.Tensor: + """ + Replacement for LA.norm since MPS does not support it yet. + + Args: + x: + + Returns: + + """ + return x.norm(p=2, dim=-1) + + +def robust_normalize(x: torch.Tensor, + dim: int = -1, + p: typing.Union[int, str] = 2) -> torch.Tensor: + """ + Normalization with a constant small term on the denominator + + Args: + x (): tensor to normalize + dim (): the dimension along which to perform the normalization + p (): the p in l-p + + Returns: + the normalized result + + """ + return x / (x.norm(p=p, dim=dim, keepdim=True).clamp(4e-5)) + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + # The following from PyTorch3d + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4) or (..., 3). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
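 + + Note: a 3-component input is treated as the imaginary part of a + quaternion whose real part is fixed to 1; the 2 / |q|^2 factor below + makes the result invariant to the quaternion's overall scale.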
+ """ + if quaternions.shape[-1] == 3: + quaternions = torch.cat( + (torch.ones_like(quaternions[..., 0:1]), quaternions), dim=-1) + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def batch_matrix_vector(matrix: torch.Tensor, + vector: torch.Tensor) -> torch.Tensor: + """ + Perform batched matrix vector product on the last dimension + + Args: + matrix (): of shape (*, d, d) + vector (): of shape (*, d) + + Returns: + the product of the two + + """ + assert len(matrix.shape[:-2]) == len(vector.shape[:-1]) + + return torch.einsum('...cd, ...d -> ...c', matrix, vector) + + +def create_pseudo_beta(atom_pos: torch.Tensor, + atom_mask: torch.Tensor) -> torch.Tensor: + """ + + Args: + atom_pos: the atom position in atom14 format, + of shape [*, num_res, 14, 3] + atom_mask: the atom mask in atom14 format, + of shape [*, num_res, 14] + + Returns: + CB coordinate (when available) and CA coordinate (when not available) + + """ + if not (atom_mask.shape[-1] == atom_pos.shape[-2] == 14): + raise ValueError('Only supports atom 14') + pseudo_beta = torch.where( + atom_mask[..., 4:5].expand(list(atom_mask.shape[:-1]) + [3]).bool(), + atom_pos[..., 4, :], atom_pos[..., 1, :]) + return pseudo_beta + + +def bit_wise_not(boolean_tensor: torch.Tensor) -> torch.Tensor: + """For MPS devices that have no support for yet bit-wise not""" + boolean_tensor = 1 - boolean_tensor.float() + return boolean_tensor.bool() + + +# ============================================================================= +# Tests +# ============================================================================= +if __name__ == '__main__': + pass diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/residue_constants.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/residue_constants.py new file mode 100644 index 000000000..6945d7c9e --- /dev/null +++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/residue_constants.py @@ -0,0 +1,686 @@ +# ============================================================================= +# Copyright 2022 HeliXon Limited +# This file is adopted from DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Constants used in OmegaFold.""" +import Bio.PDB +import torch +# Internal import (35fd). +# Distance from one CA to next CA [trans configuration: omega = 180]. 
+from Bio.Data import PDBData + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to +# mask them for each AA type. The order is as per +# restype_order (see below). +chi_angles_mask = torch.tensor([ + [0.0, 0.0, 0.0, 0.0],  # ALA + [1.0, 1.0, 1.0, 1.0],  # ARG + [1.0, 1.0, 0.0, 0.0],  # ASN + [1.0, 1.0, 0.0, 0.0],  # ASP + [1.0, 0.0, 0.0, 0.0],  # CYS + [1.0, 1.0, 1.0, 0.0],  # GLN + [1.0, 1.0, 1.0, 0.0],  # GLU + [0.0, 0.0, 0.0, 0.0],  # GLY + [1.0, 1.0, 0.0, 0.0],  # HIS + [1.0, 1.0, 0.0, 0.0],  # ILE + [1.0, 1.0, 0.0, 0.0],  # LEU + [1.0, 1.0, 1.0, 1.0],  # LYS + [1.0, 1.0, 1.0, 0.0],  # MET + [1.0, 1.0, 0.0, 0.0],  # PHE + [1.0, 1.0, 0.0, 0.0],  # PRO + [1.0, 0.0, 0.0, 0.0],  # SER + [1.0, 0.0, 0.0, 0.0],  # THR + [1.0, 1.0, 0.0, 0.0],  # TRP + [1.0, 1.0, 0.0, 0.0],  # TYR + [1.0, 0.0, 0.0, 0.0],  # VAL + [0.0, 0.0, 0.0, 0.0],  # UNK +]) + +# Atom positions relative to the 8 rigid groups, defined by the pre-omega, +# phi, psi and chi angles: +# 0: "backbone group", +# 1: "pre-omega-group", (empty) +# 2: "phi-group", (currently empty, because it defines only hydrogens) +# 3: "psi-group", +# 4,5,6,7: "chi1,2,3,4-group" +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the +# y-axis is defined such that the dihedral-angle-defining atom (the last +# entry in chi_angles_atoms above) is in the xy-plane (with a positive +# y-coordinate). 
format: [atomname, group_idx, rel_position] +aa_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, 
-0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ["CD", 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, -1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +for aa_k, aa_dict, in aa_atom_positions.items(): + for i, v in enumerate(aa_dict): + aa_dict[i][-1] = torch.tensor(v[-1]) + aa_atom_positions[aa_k] = aa_dict + +# This mapping is used when we need to store atom data in a format that +# requires fixed atom data size for every residue (e.g. a numpy array). 
+atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': [ + 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', + '', '' + ], + 'ASN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': [ + 'N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', + '', '' + ], + 'ILE': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': [ + 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', + '', '' + ], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': + ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': [ + 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', + 'CZ2', 'CZ3', 'CH2' + ], + 'TYR': [ + 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', + 'OH', '', '' + ], + 'VAL': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X', '-'] +restype_order_with_x = { + restype: i + for i, restype in enumerate(restypes_with_x) +} +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', + 'X': 'UNK' +} + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. 
The latter contains
+# many more, and less common, three letter names as keys and maps many of these
+# to the same one letter name (including "X" and "U" which we don't use here).
+restype_3to1 = {v: k for k, v in restype_1to3.items()}
+
+restype2atom_mask = torch.zeros([len(restypes_with_x), 14])
+for k, v in restype_name_to_atom14_names.items():
+    for i, atom in enumerate(v):
+        restype2atom_mask[restype_order_with_x[
+            restype_3to1[k]]][i] = len(atom) > 0
+
+# Mask of which rigid groups exist for each restype, of shape (21, 8).
+restype_rigidgroup_mask = torch.zeros([21, 8], dtype=torch.float)
+restype_rigidgroup_mask[:, 0] = 1
+restype_rigidgroup_mask[:, 3] = 1
+restype_rigidgroup_mask[:, 4:] = chi_angles_mask
+
+
+def residx_to_3(idx):
+    """Map a residue-type index to its three-letter name."""
+    return restype_1to3[restypes[idx]]
+
+
+# Define a restype name for all unknown residues.
+unk_restype = 'UNK'
+
+resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
+resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
+
+
+def get_chi_angle_atom_indices():
+    """Returns atom indices needed to compute chi angles for all residue types.
+
+    Returns:
+        A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue
+        types are in the order specified in residue_constants.restypes +
+        unknown residue type at the end. For chi angles which are not defined
+        on the residue, the position indices default to 0.
+    """
+    chi_atom_indices = []
+    for residue_name in restypes:
+        residue_name = restype_1to3[residue_name]
+        residue_chi_angles = chi_angles_atoms[residue_name]
+        atom_indices = []
+        for chi_angle in residue_chi_angles:
+            atom_indices.append([atom_order[_atom] for _atom in chi_angle])
+        for _ in range(4 - len(atom_indices)):
+            atom_indices.append([0, 0, 0, 0])  # For those not defined on AA.
+        chi_atom_indices.append(atom_indices)
+
+    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.
+
+    return torch.tensor(chi_atom_indices)
+
+
+chi_angle_atom_indices = get_chi_angle_atom_indices()
+
+
+def _make_rigid_transformation_4x4(ex: torch.Tensor, ey: torch.Tensor,
+                                   translation: torch.Tensor) -> torch.Tensor:
+    """Create a rigid 4x4 transformation matrix from two axes and transl."""
+    # Normalize ex.
+ ex_normalized = ex / torch.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - torch.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= torch.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = torch.cross(ex_normalized, ey_normalized) + m = torch.stack([ex_normalized, ey_normalized, eznorm, translation]).T + m = torch.cat([m, torch.tensor([[0., 0., 0., 1.]])], dim=0) + return m + + +# create an array with (restype, atomtype) --> aa_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_aa = torch.zeros([21, 37], dtype=torch.long) +restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32) +restype_atom37_aa_positions = torch.zeros([21, 37, 3], dtype=torch.float32) +restype_atom14_to_aa = torch.zeros([21, 14], dtype=torch.long) +restype_atom14_mask = torch.zeros([21, 14], dtype=torch.float32) +restype_atom14_aa_positions = torch.zeros([21, 14, 3], dtype=torch.float32) +restype_aa_default_frame = torch.zeros([21, 8, 4, 4], dtype=torch.float32) + + +def _make_aa_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_pos in aa_atom_positions[resname]: + atomtype = atom_order[atomname] + restype_atom37_to_aa[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_aa_positions[restype, atomtype, :] = atom_pos + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_aa[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_aa_positions[restype, atom14idx, :] = atom_pos + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = { + name: pos + for name, _, pos in aa_atom_positions[resname] + } + + # backbone to backbone is the identity transforms + restype_aa_default_frame[restype, 0, :, :] = torch.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_aa_default_frame[restype, 1, :, :] = torch.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4(ex=atom_positions['N'] - + atom_positions['CA'], + ey=torch.tensor([1., 0., 0.]), + translation=atom_positions['N']) + restype_aa_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C']) + restype_aa_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2]) + restype_aa_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = 
_make_rigid_transformation_4x4(
+                    ex=axis_end_atom_position,
+                    ey=torch.tensor([-1., 0., 0.]),
+                    translation=axis_end_atom_position)
+                restype_aa_default_frame[restype, 4 + chi_idx, :, :] = mat
+
+
+_make_aa_constants()
+
+# Construct denser atom positions (14 dimensions instead of 37).
+restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
+restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
+
+for rt in restypes:
+    atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
+
+    restype_atom14_to_atom37.append([(atom_order[name] if name else 0)
+                                     for name in atom_names])
+
+    atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+    restype_atom37_to_atom14.append([
+        (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+        for name in atom_types
+    ])
+
+# Add dummy mapping for restype "UNK"
+restype_atom14_to_atom37.append([0] * 14)
+restype_atom37_to_atom14.append([0] * 37)
+
+restype_atom14_to_atom37 = torch.tensor(restype_atom14_to_atom37,
+                                        dtype=torch.long)
+restype_atom37_to_atom14 = torch.tensor(restype_atom37_to_atom14,
+                                        dtype=torch.long)
+chi_pi_periodic = torch.tensor([
+    [0.0, 0.0, 0.0, 0.0],  # ALA
+    [0.0, 0.0, 0.0, 0.0],  # ARG
+    [0.0, 0.0, 0.0, 0.0],  # ASN
+    [0.0, 1.0, 0.0, 0.0],  # ASP
+    [0.0, 0.0, 0.0, 0.0],  # CYS
+    [0.0, 0.0, 0.0, 0.0],  # GLN
+    [0.0, 0.0, 1.0, 0.0],  # GLU
+    [0.0, 0.0, 0.0, 0.0],  # GLY
+    [0.0, 0.0, 0.0, 0.0],  # HIS
+    [0.0, 0.0, 0.0, 0.0],  # ILE
+    [0.0, 0.0, 0.0, 0.0],  # LEU
+    [0.0, 0.0, 0.0, 0.0],  # LYS
+    [0.0, 0.0, 0.0, 0.0],  # MET
+    [0.0, 1.0, 0.0, 0.0],  # PHE
+    [0.0, 0.0, 0.0, 0.0],  # PRO
+    [0.0, 0.0, 0.0, 0.0],  # SER
+    [0.0, 0.0, 0.0, 0.0],  # THR
+    [0.0, 0.0, 0.0, 0.0],  # TRP
+    [0.0, 1.0, 0.0, 0.0],  # TYR
+    [0.0, 0.0, 0.0, 0.0],  # VAL
+    [0.0, 0.0, 0.0, 0.0],  # UNK
+])
+
+residue_atom_renaming_swaps = {
+    'ASP': {
+        'OD1': 'OD2'
+    },
+    'GLU': {
+        'OE1': 'OE2'
+    },
+    'PHE': {
+        'CD1': 'CD2',
+        'CE1': 'CE2'
+    },
+    'TYR': {
+        'CD1': 'CD2',
+        'CE1': 'CE2'
+    },
+}
+
+# Create an ambiguous atoms mask. shape: (21, 14).
+mask_ambiguous = torch.zeros((21, 14), dtype=torch.bool)
+for resname, swap in residue_atom_renaming_swaps.items():
+    for atom_name1, atom_name2 in swap.items():
+        restype = restype_order[restype_3to1[resname]]
+        atom_idx1 = restype_name_to_atom14_names[resname].index(atom_name1)
+        atom_idx2 = restype_name_to_atom14_names[resname].index(atom_name2)
+        mask_ambiguous[restype, atom_idx1] = 1
+        mask_ambiguous[restype, atom_idx2] = 1
+
+restype_3 = [restype_1to3[res] for res in restypes]
+restype_3 += ['UNK']
+
+all_matrices = {res: torch.eye(14, dtype=torch.float32) for res in restype_3}
+for resname, swap in residue_atom_renaming_swaps.items():
+    correspondences = torch.arange(14)
+    renaming_matrix = None
+    for source_atom_swap, target_atom_swap in swap.items():
+        source_index = restype_name_to_atom14_names[resname].index(
+            source_atom_swap)
+        target_index = restype_name_to_atom14_names[resname].index(
+            target_atom_swap)
+        correspondences[source_index] = target_index
+        correspondences[target_index] = source_index
+        renaming_matrix = torch.zeros((14, 14), dtype=torch.float32)
+        for index, correspondence in enumerate(correspondences):
+            renaming_matrix[index, correspondence] = 1.
+    all_matrices[resname] = renaming_matrix.to(torch.float32)
+renaming_matrices = torch.stack(
+    [all_matrices[restype] for restype in restype_3], dim=0)
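+
+# --- Illustrative usage (editor's sketch; not part of the original file;
+# names below are hypothetical). Each (14, 14) matrix in `renaming_matrices`
+# permutes the chemically equivalent atom pairs listed in
+# `residue_atom_renaming_swaps`, e.g. OD1 <-> OD2 for ASP (atom14 indices 6
+# and 7), yielding the alternative ground-truth naming:
+#
+#     asp = resname_to_idx['ASP']
+#     coords = torch.randn(14, 3)            # hypothetical atom14 coords
+#     alt = renaming_matrices[asp] @ coords  # OD1 and OD2 exchanged
+#     assert torch.allclose(alt[6], coords[7])
+#     assert torch.allclose(alt[7], coords[6])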
+
+
+def substitute(res: str):
+    """Map a PDB residue name onto one of the 21 canonical resnames.
+
+    Returns None if the residue cannot be interpreted as an amino acid.
+    """
+    if Bio.PDB.is_aa(res):
+        if res in resnames:
+            return res
+        else:
+            res = PDBData.protein_letters_3to1[res]
+            if res in restype_1to3:
+                return restype_1to3[res]
+            elif res == 'X':
+                return 'UNK'
+            else:
+                # did not get anything that works
+                return None
diff --git a/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/torch_utils.py b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/torch_utils.py
new file mode 100644
index 000000000..93307d603
--- /dev/null
+++ b/opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/torch_utils.py
@@ -0,0 +1,147 @@
+# =============================================================================
+# Copyright 2022 HeliXon Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""
+PyTorch utilities
+"""
+# =============================================================================
+# Imports
+# =============================================================================
+import numbers
+import typing
+
+import torch
+from torch.nn import functional as F
+
+# =============================================================================
+# Constants
+# =============================================================================
+
+T = typing.TypeVar('T')
+
+
+# =============================================================================
+# Functions
+# =============================================================================
+def mask2bias(mask: torch.Tensor, *, inf: float = 1e9) -> torch.Tensor:
+    """Convert mask to attention bias
+
+    Args:
+        mask: the mask to convert to bias representation
+        inf: the floating point number to represent infinity
+
+    Returns:
+        bias representation for masking in attention
+
+    """
+    return mask.float().sub(1).mul(inf)
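+
+
+# --- Illustrative usage (editor's sketch; not part of the original file).
+# Valid positions (mask == 1) map to a bias of 0 and masked positions
+# (mask == 0) to -1e9, so the bias can be added directly to attention logits
+# before the softmax:
+#
+#     mask = torch.tensor([1., 1., 0.])
+#     bias = mask2bias(mask)               # tensor([0., 0., -1e9])
+#     probs = (logits + bias).softmax(-1)  # masked position gets ~0 weight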
+
+
+def normalize(inputs: torch.Tensor,
+              normalized_shape: typing.Optional[typing.Union[
+                  int, typing.List[int], torch.Size]] = None,
+              in_place: bool = False) -> torch.Tensor:
+    """Layer normalization without a module (and weight)
+
+    Args:
+        inputs: the input tensor to be normalized
+        normalized_shape: the normalized_shape for normalization
+        in_place: whether to perform the operation in-place
+
+    Returns:
+        normalized tensor
+
+    """
+    if normalized_shape is None:
+        normalized_shape = inputs.shape[-1]
+    if isinstance(normalized_shape, numbers.Integral):
+        normalized_shape = (normalized_shape, )
+
+    if in_place:
+        # This seems to create a small discrepancy in the result
+        dim = list(range(len(inputs.shape))[-len(normalized_shape):])
+        inputs -= inputs.mean(dim=dim, keepdim=True)
+        inputs *= torch.rsqrt(inputs.var(dim=dim, keepdim=True) + 1e-5)
+        return inputs
+    else:
+        # F.layer_norm seems a bit faster
+        return F.layer_norm(inputs, normalized_shape, None, None, 1e-5)
+
+
+def masked_mean(values: torch.Tensor,
+                mask: torch.Tensor,
+                dim: typing.Union[int, typing.Sequence[int], None],
+                keepdim: typing.Optional[bool] = False,
+                eps: typing.Optional[float] = 4e-5) -> torch.Tensor:
+    """Mean operation with mask
+
+    Args:
+        values: the values to take the mean for
+        mask: the mask to take the mean with
+        dim: the dimension along which to take the mean
+        keepdim: whether to keep the reduced dimension
+        eps: small constant added to the denominator to avoid division by zero
+
+    Returns:
+        mean result
+
+    """
+    values = values.masked_fill(~mask.bool(), 0).sum(dim, keepdim=keepdim)
+    norm = mask.sum(dim, keepdim=keepdim, dtype=values.dtype) + eps
+    return values / norm
+
+
+def recursive_to(obj: typing.Any, **kwargs) -> typing.Any:
+    r"""
+    Recursively move tensors (and containers of tensors) with ``Tensor.to``
+
+    *args is removed because it brings problems when calling ``.cpu()``
+
+    Args:
+        obj: the object to move
+        kwargs: keyword arguments forwarded to ``torch.Tensor.to``
+
+    Returns:
+        the moved tensors in their original container structure
+
+    """
+    if isinstance(obj, torch.Tensor):
+        try:
+            return obj.to(**kwargs)
+        except RuntimeError:
+            # e.g. the target device does not accept `non_blocking`
+            kwargs.pop('non_blocking', None)
+            return obj.to(**kwargs)
+    elif isinstance(obj, list):
+        return [recursive_to(o, **kwargs) for o in obj]
+    elif isinstance(obj, tuple):
+        return tuple(recursive_to(o, **kwargs) for o in obj)
+    elif isinstance(obj, set):
+        return set(recursive_to(o, **kwargs) for o in obj)
+    elif isinstance(obj, dict):
+        return {k: recursive_to(v, **kwargs) for k, v in obj.items()}
+    elif hasattr(obj, 'to'):
+        # this takes care of classes that implement the `to` method
+        return obj.to(**kwargs)
+    else:
+        return obj
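+
+
+# --- Illustrative usage (editor's sketch; not part of the original file).
+# `masked_mean` averages only over positions where the mask is 1:
+#
+#     values = torch.tensor([1., 2., 100.])
+#     mask = torch.tensor([1., 1., 0.])
+#     masked_mean(values, mask, dim=-1)  # ~1.5; the masked 100. is ignored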
"./data/phybench", }, + + # SciReasoner + "opencompass/SciReasoner-bio_instruction":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/bio_instruction", + }, + "opencompass/SciReasoner-Conditional_generation":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/Conditional_generation", + }, + "opencompass/SciReasoner-GUE":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/GUE-test", + }, + "opencompass/SciReasoner-LLM4Mat":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/LLM4Mat-test", + }, + "opencompass/SciReasoner-Mol_Instructions":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/Mol-Instructions-test", + }, + "opencompass/SciReasoner-OPI":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/OPI_test", + }, + "opencompass/SciReasoner-PEER":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/PEER-test", + }, + "opencompass/SciReasoner-smol":{ + "ms_id": "", + "hf_id": "", + "local": "./data/SciReasoner/smol-test", + }, } DATASETS_URL = { diff --git a/requirements/extra.txt b/requirements/extra.txt index a452cc924..9b553818c 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -2,6 +2,9 @@ alpaca-eval==0.6 # OlympiadBench antlr4-python3-runtime==4.11 +# UPG +Bio +# OlympiadBench cn2an # Dingo dingo-python==1.5.0 @@ -23,10 +26,14 @@ pint pyext # Law Bench pypinyin +# LLM4Chem +rdchiral # Smolinstruct rdkit # Molinstructions selfies +# Scireasoner Composition Material +smact # IFBench syllapy # RULER