Skip to content

Commit 42f029c

Browse files
Update models
1 parent ef679ae commit 42f029c

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

dd_ranking/utils/utils.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -268,61 +268,56 @@ def get_vgg(model_name, im_size, channel, num_classes, depth=11, batchnorm=False
268268
if use_torchvision:
269269
if depth == 11:
270270
if batchnorm:
271-
return torchvision.models.vgg11_bn(num_classes=num_classes, pretrained=pretrained)
271+
return torchvision.models.vgg11_bn(num_classes=num_classes, pretrained=False)
272272
else:
273-
return torchvision.models.vgg11(num_classes=num_classes, pretrained=pretrained)
273+
return torchvision.models.vgg11(num_classes=num_classes, pretrained=False)
274274
elif depth == 13:
275275
if batchnorm:
276-
return torchvision.models.vgg13_bn(num_classes=num_classes, pretrained=pretrained)
276+
return torchvision.models.vgg13_bn(num_classes=num_classes, pretrained=False)
277277
else:
278-
return torchvision.models.vgg13(num_classes=num_classes, pretrained=pretrained)
278+
return torchvision.models.vgg13(num_classes=num_classes, pretrained=False)
279279
elif depth == 16:
280280
if batchnorm:
281-
return torchvision.models.vgg16_bn(num_classes=num_classes, pretrained=pretrained)
281+
return torchvision.models.vgg16_bn(num_classes=num_classes, pretrained=False)
282282
else:
283-
return torchvision.models.vgg16(num_classes=num_classes, pretrained=pretrained)
283+
return torchvision.models.vgg16(num_classes=num_classes, pretrained=False)
284284
elif depth == 19:
285285
if batchnorm:
286-
return torchvision.models.vgg19_bn(num_classes=num_classes, pretrained=pretrained)
286+
return torchvision.models.vgg19_bn(num_classes=num_classes, pretrained=False)
287287
else:
288-
return torchvision.models.vgg19(num_classes=num_classes, pretrained=pretrained)
288+
return torchvision.models.vgg19(num_classes=num_classes, pretrained=False)
289289
else:
290290
model = VGG(f'VGG{depth}', channel, num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0])
291-
if pretrained:
292-
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))
293-
294-
return model
291+
292+
if pretrained:
293+
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))
294+
return model
295295

296296

297297
def get_resnet(model_name, im_size, channel, num_classes, depth=18, batchnorm=False, use_torchvision=False, pretrained=False, model_path=None):
298298
print(f"Creating {model_name} with channel={channel}, num_classes={num_classes}")
299299
if use_torchvision:
300+
print(f"ResNet in torchvision uses batchnorm by default.")
300301
if depth == 18:
301-
if batchnorm:
302-
return torchvision.models.resnet18_bn(num_classes=num_classes, pretrained=pretrained)
303-
else:
304-
return torchvision.models.resnet18(num_classes=num_classes, pretrained=pretrained)
302+
model = torchvision.models.resnet18(num_classes=num_classes, pretrained=False)
305303
elif depth == 34:
306-
if batchnorm:
307-
return torchvision.models.resnet34_bn(num_classes=num_classes, pretrained=pretrained)
308-
else:
309-
return torchvision.models.resnet34(num_classes=num_classes, pretrained=pretrained)
304+
model = torchvision.models.resnet34(num_classes=num_classes, pretrained=False)
310305
elif depth == 50:
311-
if batchnorm:
312-
return torchvision.models.resnet50_bn(num_classes=num_classes, pretrained=pretrained)
313-
else:
314-
return torchvision.models.resnet50(num_classes=num_classes, pretrained=pretrained)
306+
model = torchvision.models.resnet50(num_classes=num_classes, pretrained=False)
307+
if im_size == (64, 64):
308+
model.conv1 = torch.nn.Conv2d(3,64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
309+
model.maxpool = torch.nn.Identity()
315310
else:
316311
if depth == 18:
317312
model = ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0])
318313
elif depth == 34:
319314
model = ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0])
320315
elif depth == 50:
321316
model = ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0])
322-
if pretrained:
323-
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))
324-
325-
return model
317+
318+
if pretrained:
319+
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))
320+
return model
326321

327322

328323
def get_other_models(model_name, channel, num_classes, im_size=(32, 32), pretrained=False, model_path=None):

0 commit comments

Comments
 (0)