@@ -268,61 +268,56 @@ def get_vgg(model_name, im_size, channel, num_classes, depth=11, batchnorm=False
268268 if use_torchvision :
269269 if depth == 11 :
270270 if batchnorm :
271- return torchvision .models .vgg11_bn (num_classes = num_classes , pretrained = pretrained )
271+ return torchvision .models .vgg11_bn (num_classes = num_classes , pretrained = False )
272272 else :
273- return torchvision .models .vgg11 (num_classes = num_classes , pretrained = pretrained )
273+ return torchvision .models .vgg11 (num_classes = num_classes , pretrained = False )
274274 elif depth == 13 :
275275 if batchnorm :
276- return torchvision .models .vgg13_bn (num_classes = num_classes , pretrained = pretrained )
276+ return torchvision .models .vgg13_bn (num_classes = num_classes , pretrained = False )
277277 else :
278- return torchvision .models .vgg13 (num_classes = num_classes , pretrained = pretrained )
278+ return torchvision .models .vgg13 (num_classes = num_classes , pretrained = False )
279279 elif depth == 16 :
280280 if batchnorm :
281- return torchvision .models .vgg16_bn (num_classes = num_classes , pretrained = pretrained )
281+ return torchvision .models .vgg16_bn (num_classes = num_classes , pretrained = False )
282282 else :
283- return torchvision .models .vgg16 (num_classes = num_classes , pretrained = pretrained )
283+ return torchvision .models .vgg16 (num_classes = num_classes , pretrained = False )
284284 elif depth == 19 :
285285 if batchnorm :
286- return torchvision .models .vgg19_bn (num_classes = num_classes , pretrained = pretrained )
286+ return torchvision .models .vgg19_bn (num_classes = num_classes , pretrained = False )
287287 else :
288- return torchvision .models .vgg19 (num_classes = num_classes , pretrained = pretrained )
288+ return torchvision .models .vgg19 (num_classes = num_classes , pretrained = False )
289289 else :
290290 model = VGG (f'VGG{ depth } ' , channel , num_classes , norm = 'batchnorm' if batchnorm else 'instancenorm' , res = im_size [0 ])
291- if pretrained :
292- model . load_state_dict ( torch . load ( model_path , map_location = 'cpu' , weights_only = True ))
293-
294- return model
291+
292+ if pretrained :
293+ model . load_state_dict ( torch . load ( model_path , map_location = 'cpu' , weights_only = True ))
294+ return model
295295
296296
297297def get_resnet (model_name , im_size , channel , num_classes , depth = 18 , batchnorm = False , use_torchvision = False , pretrained = False , model_path = None ):
298298 print (f"Creating { model_name } with channel={ channel } , num_classes={ num_classes } " )
299299 if use_torchvision :
300+ print (f"ResNet in torchvision uses batchnorm by default." )
300301 if depth == 18 :
301- if batchnorm :
302- return torchvision .models .resnet18_bn (num_classes = num_classes , pretrained = pretrained )
303- else :
304- return torchvision .models .resnet18 (num_classes = num_classes , pretrained = pretrained )
302+ model = torchvision .models .resnet18 (num_classes = num_classes , pretrained = False )
305303 elif depth == 34 :
306- if batchnorm :
307- return torchvision .models .resnet34_bn (num_classes = num_classes , pretrained = pretrained )
308- else :
309- return torchvision .models .resnet34 (num_classes = num_classes , pretrained = pretrained )
304+ model = torchvision .models .resnet34 (num_classes = num_classes , pretrained = False )
310305 elif depth == 50 :
311- if batchnorm :
312- return torchvision . models . resnet50_bn ( num_classes = num_classes , pretrained = pretrained )
313- else :
314- return torchvision . models . resnet50 ( num_classes = num_classes , pretrained = pretrained )
306+ model = torchvision . models . resnet50 ( num_classes = num_classes , pretrained = False )
307+ if im_size == ( 64 , 64 ):
308+ model . conv1 = torch . nn . Conv2d ( 3 , 64 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 ), bias = False )
309+ model . maxpool = torch . nn . Identity ( )
315310 else :
316311 if depth == 18 :
317312 model = ResNet (BasicBlock , [2 ,2 ,2 ,2 ], channel = channel , num_classes = num_classes , norm = 'batchnorm' if batchnorm else 'instancenorm' , res = im_size [0 ])
318313 elif depth == 34 :
319314 model = ResNet (BasicBlock , [3 ,4 ,6 ,3 ], channel = channel , num_classes = num_classes , norm = 'batchnorm' if batchnorm else 'instancenorm' , res = im_size [0 ])
320315 elif depth == 50 :
321316 model = ResNet (Bottleneck , [3 ,4 ,6 ,3 ], channel = channel , num_classes = num_classes , norm = 'batchnorm' if batchnorm else 'instancenorm' , res = im_size [0 ])
322- if pretrained :
323- model . load_state_dict ( torch . load ( model_path , map_location = 'cpu' , weights_only = True ))
324-
325- return model
317+
318+ if pretrained :
319+ model . load_state_dict ( torch . load ( model_path , map_location = 'cpu' , weights_only = True ))
320+ return model
326321
327322
328323def get_other_models (model_name , channel , num_classes , im_size = (32 , 32 ), pretrained = False , model_path = None ):
0 commit comments