-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain_tables.py
More file actions
57 lines (52 loc) · 2.18 KB
/
main_tables.py
File metadata and controls
57 lines (52 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import os
import time

import pandas as pd

from trainer import train
def _load_experiment_args(exps, exp_id):
    """Build the argument dict passed to ``train`` for experiment ``exp_id``.

    ``exps`` is the DataFrame read from ``./args/<dataset>_publish.csv``; the
    first row whose ``ID`` column equals ``exp_id`` supplies the arguments.

    Raises:
        KeyError: if no row has ``ID == exp_id``.
    """
    rows = exps[exps['ID'] == exp_id].to_dict('records')
    if not rows:
        raise KeyError(f"experiment ID {exp_id} not found in the args CSV")
    args = rows[0]
    # train() is given seed and device wrapped in single-element lists.
    args['seed'] = [args['seed']]
    args['device'] = [args['device']]
    args['do_not_save'] = True
    return args


def _run_table(dataset, exps, ids, id_names, out_path):
    """Run each experiment in ``ids`` and record its result in ``out_path``.

    The result of ``ids[k]`` is stored in row ``id_names[k]`` of a one-column
    DataFrame (column ``dataset``); a final ``'runtime'`` row holds the
    wall-clock seconds elapsed since this table started. The CSV is rewritten
    after every experiment so completed runs survive an interruption.

    Returns:
        The filled-in results DataFrame.

    Raises:
        ValueError: if ``ids`` and ``id_names`` differ in length.
    """
    if len(ids) != len(id_names):
        raise ValueError(
            f"{len(ids)} experiment IDs but {len(id_names)} row names")
    results = pd.DataFrame(columns=[dataset], index=[*id_names, 'runtime'])
    t0 = time.time()
    for exp_id, name in zip(ids, id_names):
        ave_accs = train(_load_experiment_args(exps, exp_id))
        # Last element of the first sequence train() returns — presumably the
        # average accuracy after the final task; TODO confirm in trainer.
        results.at[name, dataset] = ave_accs[0][-1]
        results.at['runtime', dataset] = time.time() - t0
        results.to_csv(out_path)
    return results


def main():
    """Regenerate the table data for the dataset selected with ``-d``.

    Experiment settings are read from ``./args/<dataset>_publish.csv`` and each
    experiment is run through ``train``. Results are written to
    ``paper_tables/Table_data_main_<dataset>.csv`` (Tables 1 and 2) and
    ``paper_tables/Table_data_si_<dataset>.csv`` (Tables A6, A7 and A8).
    For an unsupported or missing dataset a message is printed and the
    function returns ``None`` without running anything.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', type=str, help='dataset name, e.g. cifar224')
    a = parser.parse_args()
    datasets = ['cifar224', 'imageneta']
    if a.d not in datasets:
        # f-string (not '+') so an omitted -d (a.d is None) prints instead of
        # raising TypeError.
        print(f'Dataset {a.d} not supported')
        return
    dataset = a.d

    # The args CSV is the same for every experiment, so read it once.
    exps = pd.read_csv(f'./args/{dataset}_publish.csv')
    # Create the output folder before training so a missing directory cannot
    # make to_csv fail after a long run has already finished.
    os.makedirs('paper_tables', exist_ok=True)

    # Data in Table 1 and Table 2.
    # NOTE(review): none of 'core50', 'cddb', 'domainnet' is in `datasets`
    # above, so this branch is currently unreachable.
    if dataset in ['core50', 'cddb', 'domainnet']:
        ids = [7, 6, 4, 3, 5, 2]
        id_names = ['Algorithm 1', 'No RPs', 'No Phase 1',
                    'No RPs or Phase 1', 'NCM with Phase 1', 'NCM only']
    else:
        ids = [0, 7, 6, 4, 3, 5, 2, 1]
        id_names = ['Joint linear probe', 'Algorithm 1', 'No RPs',
                    'No Phase 1', 'No RPs or Phase 1', 'NCM with Phase 1',
                    'NCM only', 'Joint full fine-tuning']
    _run_table(dataset, exps, ids, id_names,
               f'paper_tables/Table_data_main_{dataset}.csv')

    # Data in Tables A6, A7 and A8.
    ids = [10, 9, 8, 13, 12, 11, 16, 15, 14]
    id_names = ['ResNet50, Phase 2', 'ResNet50, No RPs', 'ResNet50, NCM only']
    id_names += ['ResNet152, Phase 2', 'ResNet152, No RPs',
                 'ResNet152, NCM only']
    id_names += ['CLIP, Phase 2', 'CLIP, No RPs', 'CLIP, NCM only']
    _run_table(dataset, exps, ids, id_names,
               f'paper_tables/Table_data_si_{dataset}.csv')
if __name__ == '__main__':
    # Only drive the table generation when executed as a script; importing
    # this module has no side effects beyond the definitions above.
    main()