Skip to content

Commit 1861b49

Browse files
committed
Update utils2.py
1 parent c13ed64 commit 1861b49

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

CIFAR10_code/utils2.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,27 +71,31 @@ def train(epoch, epochs, model, dataloader, criterion, optimizer, scheduler = No
7171
running_accuracy = 0.0
7272
device = 'cuda' if torch.cuda.is_available() else 'cpu'
7373

74-
74+
model.train()
7575
train_step = len(dataloader)
7676
with tqdm(total=train_step,desc=f'Train Epoch {epoch + 1}/{epochs}',postfix=dict,mininterval=0.3) as pbar:
7777
for step,(data, target) in enumerate(dataloader):
7878
data = data.to(device)
7979
target = target.to(device)
80+
#---------------------
81+
# 释放内存
82+
#---------------------
83+
if hasattr(torch.cuda, 'empty_cache'):
84+
torch.cuda.empty_cache()
85+
optimizer.zero_grad()
86+
8087
output = model(data)
8188
loss = criterion(output, target)
82-
83-
optimizer.zero_grad()
84-
loss.backward()
85-
optimizer.step()
89+
8690

8791
acc = get_acc(output,target)
8892
running_accuracy += acc
8993
running_loss += loss.data
90-
91-
lr = optimizer.param_groups[0]['lr']
94+
loss.backward()
95+
optimizer.step()
96+
9297
pbar.set_postfix(**{'Train Acc' : running_accuracy.item()/(step+1),
93-
'Train Loss' :running_loss.item()/(step+1),
94-
'Lr' : lr})
98+
'Train Loss' :running_loss.item()/(step+1)})
9599
pbar.update(1)
96100
if scheduler:
97101
scheduler.step(running_loss)
@@ -122,8 +126,13 @@ def evaluation(epoch, epochs, model, dataloader, criterion):
122126
for step,(data, target) in enumerate(dataloader):
123127
data = data.to(device)
124128
target = target.to(device)
125-
129+
#---------------------
130+
# 释放内存
131+
#---------------------
132+
if hasattr(torch.cuda, 'empty_cache'):
133+
torch.cuda.empty_cache()
126134
output = model(data)
135+
127136
loss = criterion(output, target)
128137
acc = get_acc(output,target)
129138

0 commit comments

Comments
 (0)