-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
125 lines (96 loc) · 3.84 KB
/
utils.py
File metadata and controls
125 lines (96 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import random
import numpy as np
import torch
from argparse import ArgumentParser
class AverageMeter(object):
    """Keep the latest value of a metric together with its running weighted mean.

    Attributes (all start at 0 and are zeroed again by ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count`` as of the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic so the meter can be reused."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as a measurement carrying weight ``n``.

        ``n`` is typically the number of samples ``val`` was averaged over
        (e.g. a batch size), which makes ``avg`` a per-sample mean.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def count_acc(logits, label):
    """Return the top-1 accuracy of ``logits`` against ``label`` as a percentage.

    Args:
        logits: score tensor whose dim 1 indexes classes
            (e.g. ``(batch, num_classes)``).
        label: integer class indices comparable element-wise with the
            argmax of ``logits`` over dim 1.

    Returns:
        Python float in [0, 100] (NaN for an empty batch).
    """
    pred = torch.argmax(logits, dim=1)
    # .float() keeps the comparison on whatever device it already lives on.
    # The previous torch.cuda.FloatTensor cast copied CPU tensors to the GPU
    # whenever CUDA was merely available, and relied on a legacy tensor type.
    return (pred == label).float().mean().item() * 100
def is_debug():
    """Return True when a Python trace function is currently installed.

    Debuggers (and tools such as coverage tracers) install one via
    ``sys.settrace``. Returns False when the interpreter lacks
    ``sys.gettrace`` or when no trace function is set.
    """
    import sys
    tracer_getter = getattr(sys, 'gettrace', None)
    return tracer_getter is not None and bool(tracer_getter())
def update_ema_variables(ema_model, model, alpha):
    """Blend ``model``'s state into ``ema_model`` as an exponential moving average.

    For every floating-point entry of ``model.state_dict()`` (parameters and
    float buffers such as BatchNorm running statistics)::

        ema = alpha * ema + (1 - alpha) * current

    Non-floating-point entries (e.g. BatchNorm's integer
    ``num_batches_tracked`` counter) cannot be meaningfully averaged, so they
    are copied from ``model`` verbatim.

    Args:
        ema_model: module holding the averaged state; updated in place.
        model: module with the same state_dict keys and shapes.
        alpha: decay factor; larger values retain more of the old average.
    """
    ema_dict = ema_model.state_dict()
    for k, v in model.state_dict().items():
        if torch.is_floating_point(v):
            ema_dict[k] = alpha * ema_dict[k] + (1 - alpha) * v
        else:
            # Blending an integer tensor produces a float that load_state_dict
            # would truncate back into the integer buffer; copy it instead.
            ema_dict[k] = v
    ema_model.load_state_dict(ema_dict)
def add_arguments(parser: ArgumentParser):
    """Register the training script's command-line options on ``parser``.

    Options are grouped as program-level settings, optimizer / learning-rate
    schedule settings, and dataset settings. The same parser object is
    returned so the call can be chained.
    """
    add = parser.add_argument

    # --- program-level options ---
    add('--seed', type=int, default=7)
    add('--gpus', type=str, default='0')
    add('--max_epochs', type=int, default=200)
    add('--model', type=str, default='ce')
    add('--network', type=str, default='resnet18')
    add('--init_weights', type=str, default=None)
    add('--eval', action='store_true', default=False)
    add('--eval_epochs', type=int, default=500)
    add('--val_every_n_epoch', type=int, default=1)

    # --- optimizer / schedule options ---
    add('--optim', type=str, default='sgd', choices=['sgd', 'adam'])
    add('--lr', type=float, default=0.03)
    add('--scheduler', type=str, default='step', choices=['step', 'cosine'])
    # NOTE(review): presumably LR-decay milestones for the 'step' scheduler
    # (decayed by --gamma); confirm at the call site.
    add('--point', type=int, nargs='+', default=(100, 50, 20))
    add('--gamma', type=float, default=0.1)
    add('--wd', type=float, default=0.0005)  # weight decay
    add('--mo', type=float, default=0.9)  # momentum
    add('--warmup', action='store_true', default=False)

    # --- dataset options ---
    add('--dataset', type=str, default='cifar10')
    add('--data_dir', type=str, default='./dataset')
    add('--batch_size', type=int, default=512)
    add('--train_transform', type=str, default='simclr')

    return parser
def set_gpu(args):
    """Restrict visible CUDA devices to ``args.gpus`` and renumber them.

    ``args.gpus`` is a comma-separated device list such as ``'2,3'``. It is
    exported as ``CUDA_VISIBLE_DEVICES`` and printed. Then ``args.gpus`` is
    rewritten in place to the post-masking indices ``'0,1,...'`` (one per
    requested device), because CUDA renumbers the visible devices from 0.
    """
    requested = args.gpus
    os.environ['CUDA_VISIBLE_DEVICES'] = requested
    print('using gpu:', requested)
    num_devices = len(requested.split(','))
    args.gpus = ','.join(map(str, range(num_devices)))
def set_random_seed(seed):
    """Seed all RNGs used in training and configure cuDNN for reproducibility.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), then selects deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one. This is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # `deterministic` alone is not enough: benchmark mode auto-tunes the
    # convolution algorithm per run, which can differ between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def model_load_weights(init_weights, model):
    """Load matching weights from a checkpoint file into ``model`` in place.

    The file is read with ``torch.load`` (tensors mapped to CPU); if that
    fails, it is retried as a plain pickle. If the loaded object contains a
    ``'params'`` or ``'state_dict'`` key (checked in that order), the value
    under the first key found is used as the weight dict.

    Only entries whose key also exists in ``model.state_dict()`` are loaded;
    every other model entry keeps its current value. The keys that were
    loaded are printed.

    Args:
        init_weights: path to the checkpoint file.
        model: ``torch.nn.Module`` whose state is updated.

    Raises:
        Whatever the pickle fallback raises when neither loader can read the
        file (the ``torch.load`` error is kept as its context), and
        ``RuntimeError`` from ``load_state_dict`` on shape mismatches.
    """
    model_state_dict = model.state_dict()
    try:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict below copies onto the model's own device.
        pretrained_dict = torch.load(init_weights, map_location='cpu')
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        # NOTE(review): pickle.load (and torch.load on older PyTorch) can
        # execute arbitrary code from the file -- only load trusted files.
        import pickle
        with open(init_weights, 'rb') as fp:
            pretrained_dict = pickle.load(fp)

    # Unwrap common checkpoint containers.
    for wrapper_key in ('params', 'state_dict'):
        if wrapper_key in pretrained_dict:
            pretrained_dict = pretrained_dict[wrapper_key]
            break

    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_state_dict}
    print(pretrained_dict.keys())
    model_state_dict.update(pretrained_dict)
    model.load_state_dict(model_state_dict)