Skip to content

Commit 56b1a5c

Browse files
committed
update utils.py
1 parent 7e749ce commit 56b1a5c

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

CIFAR10_code/utils_2.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
2+
'''
3+
对训练函数进行更新
4+
可视化更加方便,更加直观
5+
'''
6+
import os
7+
import matplotlib.pyplot as plt
8+
from tqdm import tqdm
9+
import torch
10+
def get_acc(outputs, label):
11+
total = outputs.shape[0]
12+
probs, pred_y = outputs.data.max(dim=1) # 得到概率
13+
correct = (pred_y == label).sum().data
14+
return correct / total
15+
16+
def plot_history(epochs, Acc = None, Loss=None, lr=None):
17+
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
18+
plt.style.use('seaborn')
19+
20+
if Acc or Loss or lr:
21+
if not os.path.isdir('vis'):
22+
os.mkdir('vis')
23+
epoch_list = range(1,epochs + 1)
24+
25+
if Loss:
26+
plt.plot(epoch_list, Loss['train_loss'])
27+
plt.plot(epoch_list, Loss['val_loss'])
28+
plt.xlabel('epoch')
29+
plt.ylabel('Loss Value')
30+
plt.legend(['train', 'val'], loc='upper left')
31+
plt.savefig('vis/history_Loss.png')
32+
plt.show()
33+
34+
if Acc:
35+
plt.plot(epoch_list, Acc['train_acc'])
36+
plt.plot(epoch_list, Acc['val_acc'])
37+
plt.xlabel('epoch')
38+
plt.ylabel('Acc Value')
39+
plt.legend(['train', 'val'], loc='upper left')
40+
plt.savefig('vis/history_Acc.png')
41+
plt.show()
42+
43+
if lr:
44+
plt.plot(epoch_list, lr)
45+
plt.xlabel('epoch')
46+
plt.ylabel('Train LR')
47+
plt.savefig('vis/history_Lr.png')
48+
plt.show()
49+
50+
51+
def train(epoch, epochs, model, dataloader, criterion, optimizer, scheduler = None):
52+
53+
'''
54+
Function used to train the model over a single epoch and update it according to the
55+
calculated gradients.
56+
57+
Args:
58+
model: Model supplied to the function
59+
dataloader: DataLoader supplied to the function
60+
criterion: Criterion used to calculate loss
61+
optimizer: Optimizer used update the model
62+
scheduler: Scheduler used to update the learing rate for faster convergence
63+
(Commented out due to poor results)
64+
resnet_features: Model to get Resnet Features for the hybrid architecture (Default=None)
65+
66+
Output:
67+
running_loss: Training Loss (Float)
68+
running_accuracy: Training Accuracy (Float)
69+
'''
70+
running_loss = 0.0
71+
running_accuracy = 0.0
72+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
73+
74+
75+
train_step = len(dataloader)
76+
with tqdm(total=train_step,desc=f'Train Epoch {epoch + 1}/{epochs}',postfix=dict,mininterval=0.3) as pbar:
77+
for step,(data, target) in tqdm(dataloader):
78+
data = data.to(device)
79+
target = target.to(device)
80+
output = model(data)
81+
loss = criterion(output, target)
82+
83+
optimizer.zero_grad()
84+
loss.backward()
85+
optimizer.step()
86+
87+
acc = get_acc(output,target)
88+
running_accuracy += acc
89+
running_loss += loss.data
90+
91+
lr = optimizer.param_groups[0]['lr']
92+
pbar.set_postfix(**{'Train Acc' : running_accuracy.item()/(step+1),
93+
'Train Loss' :running_loss.item()/(step+1),
94+
'Lr' : lr})
95+
pbar.update(1)
96+
if scheduler:
97+
scheduler.step(running_loss)
98+
running_loss, running_accuracy = running_loss/len(dataloader), running_accuracy/len(dataloader)
99+
return running_loss, running_accuracy
100+
101+
102+
def evaluation(epoch, epochs, model, dataloader, criterion):
103+
'''
104+
Function used to evaluate the model on the test dataset.
105+
106+
Args:
107+
model: Model supplied to the function
108+
dataloader: DataLoader supplied to the function
109+
criterion: Criterion used to calculate loss
110+
resnet_features: Model to get Resnet Features for the hybrid architecture (Default=None)
111+
112+
Output:
113+
test_loss: Testing Loss (Float)
114+
test_accuracy: Testing Accuracy (Float)
115+
'''
116+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
117+
eval_step = len(dataloader)
118+
with torch.no_grad():
119+
test_accuracy = 0.0
120+
test_loss = 0.0
121+
with tqdm(total=eval_step,desc=f'Evaluation Epoch {epoch + 1}/{epochs}',postfix=dict,mininterval=0.3) as pbar:
122+
for step,(data, target) in tqdm(dataloader):
123+
data = data.to(device)
124+
target = target.to(device)
125+
126+
output = model(data)
127+
loss = criterion(output, target)
128+
acc = get_acc(output,target)
129+
130+
test_accuracy += acc
131+
test_loss += loss.item()
132+
133+
pbar.set_postfix(**{'Eval Acc' : test_accuracy.item()/(step+1),
134+
'Eval Loss' :test_loss.item()/(step+1)})
135+
pbar.update(1)
136+
137+
test_loss, test_accuracy = test_loss/eval_step, test_accuracy/eval_step
138+
return test_loss, test_accuracy

0 commit comments

Comments
 (0)