Skip to content

Commit 8335a84

Browse files
committed
更新参数
1 parent bd2aa57 commit 8335a84

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

CIFAR10_code/eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def eval_top5(outputs, label):
3232
parser.add_argument('--cuda', action='store_true', default=False, help =' use GPU?')
3333
parser.add_argument('--batch-size', default=64, type=int, help = "Batch Size for Test")
3434
parser.add_argument('--num-workers', default=2, type=int, help = 'num-workers')
35-
parser.add_argument('--net', type = str, default='MobileNetv1', help='net type')
35+
parser.add_argument('--net', type = str, choices=['LeNet5', 'AlexNet', 'VGG16','VGG19','ResNet18','ResNet34',
36+
'DenseNet','MobileNetv1','MobileNetv2'], default='MobileNetv1', help='net type')
3637
args = parser.parse_args()
3738
testloader = get_test_dataloader()
3839

CIFAR10_code/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def train(epoch):
135135
'Lr' : lr})
136136
pbar.update(1)
137137
# train_loss = train_loss.item() / len(trainloader)
138-
# train_acc = train_acc.item() * 100 / len(trainloader)
138+
# train_acc = train_acc.item() * 100 / len(trainloader)
139+
scheduler.step(train_loss)
139140
print('Finish Train')
140141
def test(epoch):
141142
global best_acc
@@ -200,6 +201,6 @@ def test(epoch):
200201
for epoch in range(start_epoch, epochs):
201202
train(epoch)
202203
test(epoch)
203-
scheduler.step()
204+
204205
torch.cuda.empty_cache()
205206

0 commit comments

Comments
 (0)