-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
132 lines (109 loc) · 5.57 KB
/
main.py
File metadata and controls
132 lines (109 loc) · 5.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import hydra
import warnings
from omegaconf import OmegaConf
from dataset import get_dataset
from dataloader import get_dataloader
from model import get_model
from explainer import get_explainer
from trainer import get_trainer
from collections import defaultdict
import numpy as np
import random
from datetime import datetime
import shutil
# Silence all warnings (presumably torch / sklearn deprecation noise) to keep
# training logs readable. NOTE(review): this also hides genuinely useful
# warnings — confirm this blanket filter is intended.
warnings.filterwarnings('ignore', category=Warning)
def set_seed(seed):
    """Seed every random-number generator in use for reproducible runs.

    Covers Python's ``random``, NumPy, and torch (CPU and all CUDA devices;
    the CUDA calls are no-ops on CPU-only machines), and pins cuDNN to its
    deterministic, non-benchmarking mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def run(cfg, cur_round=0, total_round=1, dataset=None):
    """Run one experiment round for the configured explainer method.

    Unless ``cfg.calculate_all_metrics`` is set, first pretrains 5 base
    models (method name with its 3-char suffix stripped, e.g. gsat_cd ->
    gsat) and snapshots each checkpoint, then trains the full method via
    ``train_ft`` and returns its test metrics.  When
    ``cfg.calculate_all_metrics`` is set, aggregates metrics over the
    ``total_round`` saved checkpoints and terminates the process.

    NOTE(review): indentation was reconstructed from a flattened copy; the
    nesting below (pretrain loop guarded by ``not cfg.calculate_all_metrics``)
    matches the apparent control flow — confirm against the repository.

    :param cfg: composed hydra config; reads cfg.device_id, cfg.save_dir,
        cfg.method, cfg.dataset, cfg.calculate_all_metrics.
    :param cur_round: index of this repetition; suffixes the saved checkpoint.
    :param total_round: total repetitions; sizes the ensemble when aggregating.
    :param dataset: pre-loaded dataset shared across rounds.
    :returns: dict of test metrics from ``trainer.test()`` (does not return
        on the calculate_all_metrics path — calls ``exit()``).
    """
    # Use the GPU selected by cfg.device_id when available, else CPU.
    device = torch.device('cuda', index=cfg.device_id) if torch.cuda.is_available() else torch.device('cpu')
    torch.set_num_threads(4)  # cap intra-op CPU threads
    '''first train several models'''
    if not cfg.calculate_all_metrics:
        # Pretrain 5 base models; their checkpoints are copied aside so they
        # can be reused/ensembled later. The count 5 is hard-coded.
        for round_index in range(5):
            pretrain_method_name = cfg.method.method_name[:-3]  # gsat_cd -> gsat
            dataloader = get_dataloader(dataset=dataset,
                                        batch_size=getattr(cfg.method, cfg.dataset.dataset_name).batch_size)
            model = get_model(getattr(cfg.method, cfg.dataset.dataset_name)).to(device)
            explainer = get_explainer(cfg.method.method_name, getattr(cfg.method, cfg.dataset.dataset_name)).to(device)
            '''load trainer'''
            # Dataset-level and per-dataset method configs must agree on task shape.
            assert cfg.dataset.num_class == getattr(cfg.method, cfg.dataset.dataset_name).num_class
            assert cfg.dataset.multi_label == getattr(cfg.method, cfg.dataset.dataset_name).multi_label
            save_dir = cfg.save_dir
            trainer = get_trainer(method_name=pretrain_method_name,
                                  model=model,
                                  explainer=explainer,
                                  dataloader=dataloader,
                                  cfg=getattr(cfg.method, cfg.dataset.dataset_name),
                                  device=device,
                                  save_dir=save_dir)
            print(trainer.method_name)
            trainer.train()
            metrics = trainer.test()
            print(metrics)
            # Snapshot the checkpoint under a round-indexed name:
            # foo.pth -> foo_<round_index>.pth  (assumes a '.pth' suffix).
            new_checkpoints_path = f'{trainer.checkpoints_path[:-4]}_{round_index}.pth'
            shutil.copyfile(trainer.checkpoints_path, new_checkpoints_path)
    '''load dataloader'''
    # Fresh dataloader/model/explainer for the full (suffixed) method.
    dataloader = get_dataloader(dataset=dataset,
                                batch_size=getattr(cfg.method, cfg.dataset.dataset_name).batch_size)
    model = get_model(getattr(cfg.method, cfg.dataset.dataset_name)).to(device)
    explainer = get_explainer(cfg.method.method_name, getattr(cfg.method, cfg.dataset.dataset_name)).to(device)
    '''load trainer'''
    assert cfg.dataset.num_class == getattr(cfg.method, cfg.dataset.dataset_name).num_class
    assert cfg.dataset.multi_label == getattr(cfg.method, cfg.dataset.dataset_name).multi_label
    save_dir = cfg.save_dir
    trainer = get_trainer(method_name=cfg.method.method_name,
                          model=model,
                          explainer=explainer,
                          dataloader=dataloader,
                          cfg=getattr(cfg.method, cfg.dataset.dataset_name),
                          device=device,
                          save_dir=save_dir)
    print(trainer.method_name)
    if cfg.calculate_all_metrics:  # calculate all metrics
        # Aggregate over the checkpoints of all rounds, then end the process
        # — nothing after exit() runs on this path.
        trainer.calculate_shd_auc_fid_acc(cfg.method.method_name, ensemble_numbers=np.arange(total_round))
        exit()
    '''pretrain+ft'''
    trainer.train_ft(cur_index=cur_round)
    '''test'''
    metrics = trainer.test()
    # Snapshot this round's final checkpoint under a round-indexed name.
    new_checkpoints_path = f'{trainer.checkpoints_path[:-4]}_{cur_round}.pth'
    shutil.copyfile(trainer.checkpoints_path, new_checkpoints_path)
    print(metrics)
    return metrics
@hydra.main(config_path='configs', config_name='global', version_base='1.3')
def main(cfg):
    """Hydra CLI entry point: delegates to run() with the composed config.

    NOTE(review): the __main__ block below composes its own config and calls
    run() directly, so this decorated entry point appears unused when the
    script is executed directly — confirm whether it is invoked elsewhere.
    """
    return run(cfg)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--run_time', type=int, default=1, help='suggest 1 or 10')
    parser.add_argument('--dataset', type=str, default='ba_2motifs', help='{ba_2motifs, mr, benzene, mutag}')
    parser.add_argument('--method', type=str, default='gsat', help='{att, cal, size, gsat}')
    parser.add_argument('--calculate_all_metrics', action='store_true', help='')
    args = parser.parse_args()

    # Collects each metric's per-round values so mean/std can be reported.
    accumulated_metrics = defaultdict(list)

    # Compose the config programmatically (instead of the @hydra.main path)
    # so the dataset/method choices from argparse become hydra overrides.
    with hydra.initialize(config_path="configs", version_base='1.3'):
        cfg = hydra.compose(config_name="global", overrides=[f"dataset={args.dataset}", f"method={args.method}"])
        # print(OmegaConf.to_yaml(cfg))
        OmegaConf.set_struct(cfg, False)  # allow adding keys not in the schema
        # argparse guarantees a bool for store_true flags, so assign directly
        # (the previous isinstance(..., bool) guard was dead code).
        cfg.calculate_all_metrics = args.calculate_all_metrics

        '''load dataset'''
        # Load the dataset once and share it across all rounds.
        dataset = get_dataset(dataset_dir=cfg.dataset.dataset_root,
                              dataset_name=cfg.dataset.dataset_name,
                              data_split_ratio=cfg.dataset.get('data_split_ratio', None))
        for i in range(args.run_time):
            set_seed(i)  # seed each round with its index for reproducibility
            metrics = run(cfg, cur_round=i, total_round=args.run_time, dataset=dataset)
            for key, value in metrics.items():
                accumulated_metrics[key].append(value)
        # Report (mean, std) per metric over all rounds.
        average_metrics = {key: (np.mean(values), np.std(values)) for key, values in accumulated_metrics.items()}
        print(average_metrics)