1+
2+ '''
3+ 对训练函数进行更新
4+ 可视化更加方便,更加直观
5+ '''
6+ import os
7+ import matplotlib .pyplot as plt
8+ from tqdm import tqdm
9+ import torch
10+ def get_acc (outputs , label ):
11+ total = outputs .shape [0 ]
12+ probs , pred_y = outputs .data .max (dim = 1 ) # 得到概率
13+ correct = (pred_y == label ).sum ().data
14+ return correct / total
15+
16+ def plot_history (epochs , Acc = None , Loss = None , lr = None ):
17+ plt .rcParams ['figure.figsize' ] = (10.0 , 8.0 ) # set default size of plots
18+ plt .style .use ('seaborn' )
19+
20+ if Acc or Loss or lr :
21+ if not os .path .isdir ('vis' ):
22+ os .mkdir ('vis' )
23+ epoch_list = range (1 ,epochs + 1 )
24+
25+ if Loss :
26+ plt .plot (epoch_list , Loss ['train_loss' ])
27+ plt .plot (epoch_list , Loss ['val_loss' ])
28+ plt .xlabel ('epoch' )
29+ plt .ylabel ('Loss Value' )
30+ plt .legend (['train' , 'val' ], loc = 'upper left' )
31+ plt .savefig ('vis/history_Loss.png' )
32+ plt .show ()
33+
34+ if Acc :
35+ plt .plot (epoch_list , Acc ['train_acc' ])
36+ plt .plot (epoch_list , Acc ['val_acc' ])
37+ plt .xlabel ('epoch' )
38+ plt .ylabel ('Acc Value' )
39+ plt .legend (['train' , 'val' ], loc = 'upper left' )
40+ plt .savefig ('vis/history_Acc.png' )
41+ plt .show ()
42+
43+ if lr :
44+ plt .plot (epoch_list , lr )
45+ plt .xlabel ('epoch' )
46+ plt .ylabel ('Train LR' )
47+ plt .savefig ('vis/history_Lr.png' )
48+ plt .show ()
49+
50+
51+ def train (epoch , epochs , model , dataloader , criterion , optimizer , scheduler = None ):
52+
53+ '''
54+ Function used to train the model over a single epoch and update it according to the
55+ calculated gradients.
56+
57+ Args:
58+ model: Model supplied to the function
59+ dataloader: DataLoader supplied to the function
60+ criterion: Criterion used to calculate loss
61+ optimizer: Optimizer used update the model
62+ scheduler: Scheduler used to update the learing rate for faster convergence
63+ (Commented out due to poor results)
64+ resnet_features: Model to get Resnet Features for the hybrid architecture (Default=None)
65+
66+ Output:
67+ running_loss: Training Loss (Float)
68+ running_accuracy: Training Accuracy (Float)
69+ '''
70+ running_loss = 0.0
71+ running_accuracy = 0.0
72+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
73+
74+
75+ train_step = len (dataloader )
76+ with tqdm (total = train_step ,desc = f'Train Epoch { epoch + 1 } /{ epochs } ' ,postfix = dict ,mininterval = 0.3 ) as pbar :
77+ for step ,(data , target ) in tqdm (dataloader ):
78+ data = data .to (device )
79+ target = target .to (device )
80+ output = model (data )
81+ loss = criterion (output , target )
82+
83+ optimizer .zero_grad ()
84+ loss .backward ()
85+ optimizer .step ()
86+
87+ acc = get_acc (output ,target )
88+ running_accuracy += acc
89+ running_loss += loss .data
90+
91+ lr = optimizer .param_groups [0 ]['lr' ]
92+ pbar .set_postfix (** {'Train Acc' : running_accuracy .item ()/ (step + 1 ),
93+ 'Train Loss' :running_loss .item ()/ (step + 1 ),
94+ 'Lr' : lr })
95+ pbar .update (1 )
96+ if scheduler :
97+ scheduler .step (running_loss )
98+ running_loss , running_accuracy = running_loss / len (dataloader ), running_accuracy / len (dataloader )
99+ return running_loss , running_accuracy
100+
101+
102+ def evaluation (epoch , epochs , model , dataloader , criterion ):
103+ '''
104+ Function used to evaluate the model on the test dataset.
105+
106+ Args:
107+ model: Model supplied to the function
108+ dataloader: DataLoader supplied to the function
109+ criterion: Criterion used to calculate loss
110+ resnet_features: Model to get Resnet Features for the hybrid architecture (Default=None)
111+
112+ Output:
113+ test_loss: Testing Loss (Float)
114+ test_accuracy: Testing Accuracy (Float)
115+ '''
116+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
117+ eval_step = len (dataloader )
118+ with torch .no_grad ():
119+ test_accuracy = 0.0
120+ test_loss = 0.0
121+ with tqdm (total = eval_step ,desc = f'Evaluation Epoch { epoch + 1 } /{ epochs } ' ,postfix = dict ,mininterval = 0.3 ) as pbar :
122+ for step ,(data , target ) in tqdm (dataloader ):
123+ data = data .to (device )
124+ target = target .to (device )
125+
126+ output = model (data )
127+ loss = criterion (output , target )
128+ acc = get_acc (output ,target )
129+
130+ test_accuracy += acc
131+ test_loss += loss .item ()
132+
133+ pbar .set_postfix (** {'Eval Acc' : test_accuracy .item ()/ (step + 1 ),
134+ 'Eval Loss' :test_loss .item ()/ (step + 1 )})
135+ pbar .update (1 )
136+
137+ test_loss , test_accuracy = test_loss / eval_step , test_accuracy / eval_step
138+ return test_loss , test_accuracy
0 commit comments