44from ecg_bench .config import get_args
55
def extract_file_info(file):
    """Parse the run configuration encoded in a results file name.

    Two naming schemes are recognised (fields are separated by ``_``):

    * ``seed_{seed}_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json``
    * ``statistical_results_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json``

    Both schemes have a two-token prefix (``seed``/``{seed}`` or
    ``statistical``/``results``), so every later field sits at the same
    index in either case and is parsed by a single code path.

    Args:
        file: Path (``/``-separated) or bare name of a results file.

    Returns:
        A dict with keys ``rag_used``, ``rag_k``, ``is_seed``, ``seed_num``,
        ``perturb``, ``retrieval_base``, ``retrieved_information``,
        ``rag_prompt_mode`` and ``normalized_rag_feature``. The RAG-specific
        entries are ``None`` when the rag field is not the string ``'True'``;
        ``seed_num`` is ``None`` for non-seed files.

    Raises:
        IndexError: If the name has fewer ``_``-separated fields than the
            scheme requires.
        ValueError: If the seed or ``rag_k`` field is not an integer.
    """
    filename = file.split('/')[-1]
    parts = filename.split('_')

    is_seed = filename.startswith('seed_')
    seed_num = int(parts[1]) if is_seed else None

    perturb = parts[2]
    rag_used = parts[3] == 'True'

    if rag_used:
        retrieval_base = parts[4]
        retrieved_information = parts[5]
        rag_k = int(parts[6])
        # The prompt mode occupies two '_'-separated tokens in the file name.
        # NOTE(review): the tokens are concatenated without the underscore
        # (e.g. 'system' + 'prompt' -> 'systemprompt'); confirm that this is
        # the intended display value rather than 'system_prompt'.
        rag_prompt_mode = parts[7] + parts[8]
        # Last field still carries the file extension; keep the part before it.
        normalized_rag_feature = parts[9].split('.')[0]
    else:
        retrieval_base = None
        retrieved_information = None
        rag_k = None
        rag_prompt_mode = None
        normalized_rag_feature = None

    return {
        'rag_used': rag_used,
        'rag_k': rag_k,
        'is_seed': is_seed,
        'seed_num': seed_num,
        'perturb': perturb,
        'retrieval_base': retrieval_base,
        'retrieved_information': retrieved_information,
        'rag_prompt_mode': rag_prompt_mode,
        'normalized_rag_feature': normalized_rag_feature,
    }
1356
1457def process_seed_data (data ):
1558 averages = data ['averages' ]
@@ -28,42 +71,52 @@ def collect_results(json_files):
2871 statistical_no_rag = {}
2972 individual_seeds_rag = defaultdict (dict )
3073 statistical_rag = {}
74+ config_info_no_rag = None
75+ config_info_rag = {}
3176
3277 for file in json_files :
33- rag_used , rag_k , is_seed , seed_num = extract_file_info (file )
78+ info = extract_file_info (file )
3479 with open (file , 'r' ) as f :
3580 data = json .load (f )
3681
37- if is_seed :
82+ if info [ ' is_seed' ] :
3883 metrics = process_seed_data (data )
39- if rag_used :
40- individual_seeds_rag [rag_k ][seed_num ] = metrics
84+ if info ['rag_used' ]:
85+ individual_seeds_rag [info ['rag_k' ]][info ['seed_num' ]] = metrics
86+ config_info_rag [info ['rag_k' ]] = info
4187 else :
42- individual_seeds_no_rag [seed_num ] = metrics
88+ individual_seeds_no_rag [info ['seed_num' ]] = metrics
89+ config_info_no_rag = info
4390 else :
44- if rag_used :
45- statistical_rag [rag_k ] = data
91+ if info ['rag_used' ]:
92+ statistical_rag [info ['rag_k' ]] = data
93+ config_info_rag [info ['rag_k' ]] = info
4694 else :
4795 statistical_no_rag = data
96+ config_info_no_rag = info
4897
4998 return (individual_seeds_no_rag , statistical_no_rag ,
50- individual_seeds_rag , statistical_rag )
99+ individual_seeds_rag , statistical_rag , config_info_no_rag , config_info_rag )
51100
def print_seed_results(title, seed_dict, config_info=None):
    """Print the per-seed metrics of one experiment group to stdout.

    Nothing is printed when ``seed_dict`` is empty.

    Args:
        title: Heading line printed before the results.
        seed_dict: Mapping of seed -> metrics mapping. Each metrics mapping
            is indexed with 'BLEU', 'METEOR', 'ROUGE', 'BERTSCORE' and 'ACC';
            values are multiplied by 100 before printing.
        config_info: Optional mapping with the keys 'perturb',
            'retrieval_base', 'retrieved_information', 'rag_prompt_mode' and
            'normalized_rag_feature'. When truthy, a one-line configuration
            summary is printed under the title.
    """
    if not seed_dict:
        return
    print(title)
    if config_info:
        print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}")
    # Seeds are reported in ascending order regardless of insertion order.
    for seed in sorted(seed_dict):
        print(f" Seed {seed}:")
        for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']:
            value = seed_dict[seed][metric] * 100  # Scale to 0-100
            print(f" {metric}: {value:.2f}")
    # NOTE(review): separator placed after the whole seed listing; the
    # original nesting level is not recoverable from the diff view — confirm.
    print('--------------------------------')
62113
63- def print_statistical_results (title , stats_dict ):
114+ def print_statistical_results (title , stats_dict , config_info = None ):
64115 if not stats_dict :
65116 return
66117 print (title )
118+ if config_info :
119+ print (f" Config: perturb={ config_info ['perturb' ]} , retrieval_base={ config_info ['retrieval_base' ]} , retrieved_info={ config_info ['retrieved_information' ]} , prompt_mode={ config_info ['rag_prompt_mode' ]} , normalized={ config_info ['normalized_rag_feature' ]} " )
67120 for metric in ['BLEU' , 'METEOR' , 'ROUGE' , 'BERTSCORE' , 'ACC' ]:
68121 value = (stats_dict ['ROUGE' ]['rouge-l' ] if metric == 'ROUGE' else
69122 stats_dict ['BERTSCORE' ]['hf-f1' ] if metric == 'BERTSCORE' else
@@ -89,14 +142,15 @@ def main():
89142 return
90143
91144 (individual_seeds_no_rag , statistical_no_rag ,
92- individual_seeds_rag , statistical_rag ) = collect_results (json_files )
145+ individual_seeds_rag , statistical_rag , config_info_no_rag , config_info_rag ) = collect_results (json_files )
93146
94- print_seed_results ("Individual Seed Results without RAG:" , individual_seeds_no_rag )
95- print_statistical_results ("Statistical Results without RAG:" , statistical_no_rag )
147+ print_seed_results ("Individual Seed Results without RAG:" , individual_seeds_no_rag , config_info_no_rag )
148+ print_statistical_results ("Statistical Results without RAG:" , statistical_no_rag , config_info_no_rag )
96149
97150 for k in sorted (individual_seeds_rag .keys ()):
98- print_seed_results (f"Individual Seed Results with RAG k={ k } :" , individual_seeds_rag [k ])
99- print_statistical_results (f"Statistical Results with RAG k={ k } :" , statistical_rag .get (k , {}))
151+ config_info = config_info_rag .get (k )
152+ print_seed_results (f"Individual Seed Results with RAG k={ k } :" , individual_seeds_rag [k ], config_info )
153+ print_statistical_results (f"Statistical Results with RAG k={ k } :" , statistical_rag .get (k , {}), config_info )
100154
101155 print ('================================================' )
102156
0 commit comments