@@ -71,27 +71,31 @@ def train(epoch, epochs, model, dataloader, criterion, optimizer, scheduler = No
7171 running_accuracy = 0.0
7272 device = 'cuda' if torch .cuda .is_available () else 'cpu'
7373
74-
74+ model . train ()
7575 train_step = len (dataloader )
7676 with tqdm (total = train_step ,desc = f'Train Epoch { epoch + 1 } /{ epochs } ' ,postfix = dict ,mininterval = 0.3 ) as pbar :
7777 for step ,(data , target ) in enumerate (dataloader ):
7878 data = data .to (device )
7979 target = target .to (device )
80+ #---------------------
81+ # 释放内存
82+ #---------------------
83+ if hasattr (torch .cuda , 'empty_cache' ):
84+ torch .cuda .empty_cache ()
85+ optimizer .zero_grad ()
86+
8087 output = model (data )
8188 loss = criterion (output , target )
82-
83- optimizer .zero_grad ()
84- loss .backward ()
85- optimizer .step ()
89+
8690
8791 acc = get_acc (output ,target )
8892 running_accuracy += acc
8993 running_loss += loss .data
90-
91- lr = optimizer .param_groups [0 ]['lr' ]
94+ loss .backward ()
95+ optimizer .step ()
96+
9297 pbar .set_postfix (** {'Train Acc' : running_accuracy .item ()/ (step + 1 ),
93- 'Train Loss' :running_loss .item ()/ (step + 1 ),
94- 'Lr' : lr })
98+ 'Train Loss' :running_loss .item ()/ (step + 1 )})
9599 pbar .update (1 )
96100 if scheduler :
97101 scheduler .step (running_loss )
@@ -122,8 +126,13 @@ def evaluation(epoch, epochs, model, dataloader, criterion):
122126 for step ,(data , target ) in enumerate (dataloader ):
123127 data = data .to (device )
124128 target = target .to (device )
125-
129+ #---------------------
130+ # 释放内存
131+ #---------------------
132+ if hasattr (torch .cuda , 'empty_cache' ):
133+ torch .cuda .empty_cache ()
126134 output = model (data )
135+
127136 loss = criterion (output , target )
128137 acc = get_acc (output ,target )
129138
0 commit comments