forked from yuziGuo/PolyFilterPlayground
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtune_backbone.py
More file actions
120 lines (94 loc) · 4 KB
/
tune_backbone.py
File metadata and controls
120 lines (94 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import optuna
from optuna.trial import TrialState
import optuna.study
from utils.grading_logger import _set_logger
import ipdb
import numpy as np
from opts.tune.public_hypers import public_hypers_default
from opts.tune.public_hypers import convert_dict_to_optuna_suggested
from utils.optuna_utils import _ckpt_fname
from utils.optuna_utils import _get_complete_and_pruned_trial_nums
from utils.optuna_utils import _pruneDuplicate, _CkptsAndHandlersClearerCallBack
def main():
    """Placeholder objective body: return a uniform random score in [0, 1).

    NOTE(review): this is a stub — presumably the real training loop is
    plugged in elsewhere; confirm before relying on the returned value.
    """
    # random_sample() with no args draws from the same global NumPy RNG
    # stream as np.random.rand(), so the sequence of values is unchanged.
    score = np.random.random_sample()
    return score
def initialize_args():
    """Build the static (non-tuned) options and an optuna hyper-parameter suggestor.

    Returns:
        tuple: ``(opts, suggestor)`` where ``opts`` is the parsed
        ``argparse.Namespace`` of static options (augmented with the
        project-wide ``public_static_opts``) and ``suggestor`` maps an
        optuna ``trial`` to a dict of suggested hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    # Task-level options shared by all runs.
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--model", type=str, default='OptBasisGNN')
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--dataset", type=str, default="cora")
    parser.add_argument("--lcc", action='store_true', default=False)
    # Logging options.
    parser.add_argument("--logging", action='store_true', default=False)
    parser.add_argument("--log-detail", action='store_true', default=False)
    parser.add_argument("--log-detailedCh", action='store_true', default=False)
    parser.add_argument("--id-log", type=int, default=0)
    # Optuna trial budget.
    parser.add_argument("--optuna-n-trials", type=int, default=202)

    opts = parser.parse_args()
    if opts.gpu < 0:
        # A negative gpu id is the convention for "run on CPU".
        opts.gpu = 'cpu'

    # Fold in the project-wide static settings shared by all tasks.
    from opts.tune.public_static_settings import public_static_opts
    vars(opts).update(public_static_opts)

    # Tuned options are not parsed here. Instead, a `suggestor` wraps
    # optuna's `trial.suggest_*` functions and proposes a fresh group of
    # hyper-parameters on each run (see `objective`). Most of these are
    # shared across models, e.g. learning rates and weight decays.
    suggestor = convert_dict_to_optuna_suggested(public_hypers_default)
    return opts, suggestor
def objective(trial):
    """Run one optuna trial: assemble the args, maybe prune, train, report.

    Reads the module-level globals ``static_args`` and ``suggestor`` that
    are assigned in the ``__main__`` section.

    Args:
        trial: the optuna trial driving this run.

    Returns:
        The validation accuracy produced by ``main()``.
    """
    # Merge static options, trial-suggested hypers, and a per-trial
    # early-stopping checkpoint name into one namespace.
    trial_args = argparse.Namespace()
    merged = vars(trial_args)
    merged.update(vars(static_args))
    merged.update(suggestor(trial))
    merged['es_ckpt'] = _ckpt_fname(trial.study, trial)

    logger = _set_logger(trial_args)
    logger.info(trial_args)

    # May raise optuna's prune exception if an identical trial already ran.
    _pruneDuplicate(trial)

    val_acc = main()
    trial.set_user_attr("val_acc", val_acc)
    return val_acc
if __name__ == '__main__':
    # `global` declarations at module scope are no-ops, so the original
    # `global static_args` / `global suggestor` lines were removed;
    # these are plain module-level names read by `objective`.
    static_args, suggestor = initialize_args()

    # Create (or resume) the optuna study backed by a local sqlite file.
    dataset = static_args.dataset
    kw = f'noBN-{static_args.dataset}'
    # NOTE(review): the study name is keyed on `dataset` while the db
    # filename is keyed on `kw` ('noBN-<dataset>') — confirm this
    # asymmetry is intended.
    study = optuna.create_study(
        study_name=f"GATClenRes-{dataset}",
        direction="maximize",
        storage=optuna.storages.RDBStorage(
            url='sqlite:///{}/GATClenRes-{}.db'.format('cache/OptunaTrials', kw),
            engine_kwargs={"connect_args": {"timeout": 10000}},
        ),
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=5, n_warmup_steps=15,
            interval_steps=1, n_min_trials=5,
        ),
        load_if_exists=True,
    )
    # NOTE(review): Study.set_system_attr is deprecated in recent optuna
    # releases — consider set_user_attr if the installed version warns.
    study.set_system_attr('kw', kw)

    # Run one trial per optimize() call so the completed/pruned counts
    # can be re-checked between trials.
    n_trials = static_args.optuna_n_trials
    num_completed, num_pruned = _get_complete_and_pruned_trial_nums(study)
    while num_completed + num_pruned < n_trials:
        print(f'{n_trials - num_completed - num_pruned} trials to go!')
        study.optimize(objective,
                       n_trials=1,
                       catch=(RuntimeError,),
                       callbacks=(_CkptsAndHandlersClearerCallBack(),)
                       )
        num_completed, num_pruned = _get_complete_and_pruned_trial_nums(study)
        if num_pruned > 1000:
            # Safety valve: stop if the study keeps pruning trials
            # (e.g. endless duplicate suggestions).
            break
    # report results