Skip to content

Commit 521228e

Browse files
committed
update org
1 parent aa7ff79 commit 521228e

File tree

2 files changed

+82
-22
lines changed

2 files changed

+82
-22
lines changed

ecg_bench/organize_results.py

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,55 @@
44
from ecg_bench.config import get_args
55

66
def extract_file_info(file):
7-
parts = file.split('_')
8-
rag_used = parts[-2] == 'True'
9-
rag_k = int(parts[-1].split('.')[0]) if rag_used else None
10-
is_seed = 'seed' in file
11-
seed_num = int(file.split('/')[-1].split('_')[1]) if is_seed else None
12-
return rag_used, rag_k, is_seed, seed_num
7+
filename = file.split('/')[-1]
8+
parts = filename.split('_')
9+
10+
if filename.startswith('seed_'):
11+
# seed_{seed}_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json
12+
seed_num = int(parts[1])
13+
perturb = parts[2]
14+
rag_used = parts[3] == 'True'
15+
16+
if rag_used:
17+
retrieval_base = parts[4]
18+
retrieved_information = parts[5]
19+
rag_k = int(parts[6])
20+
rag_prompt_mode = parts[7]+parts[8]
21+
normalized_rag_feature = parts[9].split('.')[0]
22+
else:
23+
retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None
24+
rag_k = None
25+
26+
is_seed = True
27+
else:
28+
# statistical_results_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json
29+
perturb = parts[2]
30+
rag_used = parts[3] == 'True'
31+
32+
if rag_used:
33+
retrieval_base = parts[4]
34+
retrieved_information = parts[5]
35+
rag_k = int(parts[6])
36+
rag_prompt_mode = parts[7]+parts[8]
37+
normalized_rag_feature = parts[9].split('.')[0]
38+
else:
39+
retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None
40+
rag_k = None
41+
42+
is_seed = False
43+
seed_num = None
44+
45+
return {
46+
'rag_used': rag_used,
47+
'rag_k': rag_k,
48+
'is_seed': is_seed,
49+
'seed_num': seed_num,
50+
'perturb': perturb,
51+
'retrieval_base': retrieval_base,
52+
'retrieved_information': retrieved_information,
53+
'rag_prompt_mode': rag_prompt_mode,
54+
'normalized_rag_feature': normalized_rag_feature
55+
}
1356

1457
def process_seed_data(data):
1558
averages = data['averages']
@@ -28,42 +71,52 @@ def collect_results(json_files):
2871
statistical_no_rag = {}
2972
individual_seeds_rag = defaultdict(dict)
3073
statistical_rag = {}
74+
config_info_no_rag = None
75+
config_info_rag = {}
3176

3277
for file in json_files:
33-
rag_used, rag_k, is_seed, seed_num = extract_file_info(file)
78+
info = extract_file_info(file)
3479
with open(file, 'r') as f:
3580
data = json.load(f)
3681

37-
if is_seed:
82+
if info['is_seed']:
3883
metrics = process_seed_data(data)
39-
if rag_used:
40-
individual_seeds_rag[rag_k][seed_num] = metrics
84+
if info['rag_used']:
85+
individual_seeds_rag[info['rag_k']][info['seed_num']] = metrics
86+
config_info_rag[info['rag_k']] = info
4187
else:
42-
individual_seeds_no_rag[seed_num] = metrics
88+
individual_seeds_no_rag[info['seed_num']] = metrics
89+
config_info_no_rag = info
4390
else:
44-
if rag_used:
45-
statistical_rag[rag_k] = data
91+
if info['rag_used']:
92+
statistical_rag[info['rag_k']] = data
93+
config_info_rag[info['rag_k']] = info
4694
else:
4795
statistical_no_rag = data
96+
config_info_no_rag = info
4897

4998
return (individual_seeds_no_rag, statistical_no_rag,
50-
individual_seeds_rag, statistical_rag)
99+
individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag)
51100

52-
def print_seed_results(title, seed_dict):
101+
def print_seed_results(title, seed_dict, config_info=None):
53102
if not seed_dict:
54103
return
55104
print(title)
105+
if config_info:
106+
print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}")
56107
for seed in sorted(seed_dict.keys()):
57108
print(f" Seed {seed}:")
58109
for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']:
59110
value = seed_dict[seed][metric] * 100 # Scale to 0-100
60111
print(f" {metric}: {value:.2f}")
61112
print('--------------------------------')
62113

63-
def print_statistical_results(title, stats_dict):
114+
def print_statistical_results(title, stats_dict, config_info=None):
64115
if not stats_dict:
65116
return
66117
print(title)
118+
if config_info:
119+
print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}")
67120
for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']:
68121
value = (stats_dict['ROUGE']['rouge-l'] if metric == 'ROUGE' else
69122
stats_dict['BERTSCORE']['hf-f1'] if metric == 'BERTSCORE' else
@@ -89,14 +142,15 @@ def main():
89142
return
90143

91144
(individual_seeds_no_rag, statistical_no_rag,
92-
individual_seeds_rag, statistical_rag) = collect_results(json_files)
145+
individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag) = collect_results(json_files)
93146

94-
print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag)
95-
print_statistical_results("Statistical Results without RAG:", statistical_no_rag)
147+
print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag, config_info_no_rag)
148+
print_statistical_results("Statistical Results without RAG:", statistical_no_rag, config_info_no_rag)
96149

97150
for k in sorted(individual_seeds_rag.keys()):
98-
print_seed_results(f"Individual Seed Results with RAG k={k}:", individual_seeds_rag[k])
99-
print_statistical_results(f"Statistical Results with RAG k={k}:", statistical_rag.get(k, {}))
151+
config_info = config_info_rag.get(k)
152+
print_seed_results(f"Individual Seed Results with RAG k={k}:", individual_seeds_rag[k], config_info)
153+
print_statistical_results(f"Statistical Results with RAG k={k}:", statistical_rag.get(k, {}), config_info)
100154

101155
print('================================================')
102156

ecg_bench/scripts/org_results.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22

33
# data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250")
44
data=("ecg_instruct_45k_mapped_1250")
5+
# retrieval_base="feature"
6+
# retrieved_information="combined"
7+
# rag_k=1
8+
# rag_prompt_mode="system_prompt"
9+
# normalized_rag_features=True
10+
511
checkpoints=(
6-
"llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_report_5_False"
12+
'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_combined_1_system_prompt_True_False'
713
)
814

915
for d in "${data[@]}"; do

0 commit comments

Comments
 (0)