11# pretrained models from pytorch: https://pytorch.org/vision/0.8/models.html
22from __future__ import print_function , division
33
4+ import itertools
45import torch
56import torch .nn as nn
67import torchvision .models as models
2021# 'mnasnet': models.mnasnet1_0
2122# }
2223
23- class ResNet18 (nn .Module ):
24+ class BuiltInNet (nn .Module ):
25+ """
26+ Built-in Network in Pytorch for classification.
27+ Parameters should be set in the `params` dictionary that contains the
28+ following fields:
29+
30+ :param input_chns: (int) Input channel number, default is 3.
31+ :param pretrain: (bool) Using pretrained model or not, default is True.
32+ :param update_mode: (str) The strategy for updating layers: "`all`" means updating
33+ all the layers, and "`last`" (by default) means updating the last layer,
34+ as well as the first layer when `input_chns` is not 3.
35+ """
2436 def __init__ (self , params ):
25- super (ResNet18 , self ).__init__ ()
26- self .params = params
27- cls_num = params ['class_num' ]
28- in_chns = params .get ('input_chns' , 3 )
37+ super (BuiltInNet , self ).__init__ ()
38+ self .params = params
39+ self .in_chns = params .get ('input_chns' , 3 )
2940 self .pretrain = params .get ('pretrain' , True )
30- self .update_layers = params .get ('update_layers' , 0 )
41+ self .update_mode = params .get ('update_mode' , "last" )
42+ self .net = None
43+
44+ def forward (self , x ):
45+ return self .net (x )
46+
47+ def get_parameters_to_update (self ):
48+ pass
49+
50+ class ResNet18 (BuiltInNet ):
51+ """
52+ ResNet18 for classification.
53+ Parameters should be set in the `params` dictionary that contains the
54+ following fields:
55+
56+ :param input_chns: (int) Input channel number, default is 3.
57+ :param pretrain: (bool) Using pretrained model or not, default is True.
58+ :param update_mode: (str) The strategy for updating layers: "`all`" means updating
59+ all the layers, and "`last`" (by default) means updating the last layer,
60+ as well as the first layer when `input_chns` is not 3.
61+ """
62+ def __init__ (self , params ):
63+ super (ResNet18 , self ).__init__ (params )
3164 self .net = models .resnet18 (pretrained = self .pretrain )
3265
3366 # replace the last layer
3467 num_ftrs = self .net .fc .in_features
35- self .net .fc = nn .Linear (num_ftrs , cls_num )
36-
37- def forward (self , x ):
38- return self .net (x )
68+ self .net .fc = nn .Linear (num_ftrs , params ['class_num' ])
69+
70+ # replace the first layer when in_chns is not 3
71+ if (self .in_chns != 3 ):
72+ self .net .conv1 = nn .Conv2d (self .in_chns , 64 , kernel_size = (7 , 7 ),
73+ stride = (2 , 2 ), padding = (3 , 3 ), bias = False )
3974
4075 def get_parameters_to_update (self ):
41- if (self .pretrain == False or self . update_layers == 0 ):
76+ if (self .update_mode == "all" ):
4277 return self .net .parameters ()
43- elif (self .update_layers == - 1 ):
44- return self .net .fc .parameters ()
78+ elif (self .update_layers == "last" ):
79+ params = self .net .fc .parameters ()
80+ if (self .in_chns != 3 ):
81+ # combining the two iterables into a single one
82+ # see: https://dzone.com/articles/python-joining-multiple
83+ params = itertools .chain ()
84+ for pram in [self .net .fc .parameters (), self .net .conv1 .parameters ()]:
85+ params = itertools .chain (params , pram )
86+ return params
4587 else :
46- raise (ValueError ("update_layers can only be 0 (all layers) " +
47- "or -1 (the last layer)" ))
88+ raise (ValueError ("update_mode can only be 'all' or 'last'." ))
4889
49- class VGG16 (nn .Module ):
90+ class VGG16 (BuiltInNet ):
91+ """
92+ VGG16 for classification.
93+ Parameters should be set in the `params` dictionary that contains the
94+ following fields:
95+
96+ :param input_chns: (int) Input channel number, default is 3.
97+ :param pretrain: (bool) Using pretrained model or not, default is True.
98+ :param update_mode: (str) The strategy for updating layers: "`all`" means updating
99+ all the layers, and "`last`" (by default) means updating the last layer,
100+ as well as the first layer when `input_chns` is not 3.
101+ """
50102 def __init__ (self , params ):
51- super (VGG16 , self ).__init__ ()
52- self .params = params
53- cls_num = params ['class_num' ]
54- in_chns = params .get ('input_chns' , 3 )
55- self .pretrain = params .get ('pretrain' , True )
56- self .update_layers = params .get ('update_layers' , 0 )
103+ super (VGG16 , self ).__init__ (params )
57104 self .net = models .vgg16 (pretrained = self .pretrain )
58105
59106 # replace the last layer
60107 num_ftrs = self .net .classifier [- 1 ].in_features
61- self .net .classifier [- 1 ] = nn .Linear (num_ftrs , cls_num )
62-
63- def forward (self , x ):
64- return self .net (x )
108+ self .net .classifier [- 1 ] = nn .Linear (num_ftrs , params ['class_num' ])
109+
110+ # replace the first layer when in_chns is not 3
111+ if (self .in_chns != 3 ):
112+ self .net .features [0 ] = nn .Conv2d (self .in_chns , 64 , kernel_size = (3 , 3 ),
113+ stride = (1 , 1 ), padding = (1 , 1 ), bias = False )
65114
66115 def get_parameters_to_update (self ):
67- if (self .pretrain == False or self . update_layers == 0 ):
116+ if (self .update_mode == "all" ):
68117 return self .net .parameters ()
69- elif (self .update_layers == - 1 ):
70- return self .net .classifier [- 1 ].parameters ()
118+ elif (self .update_mode == "last" ):
119+ params = self .net .classifier [- 1 ].parameters ()
120+ if (self .in_chns != 3 ):
121+ params = itertools .chain ()
122+ for pram in [self .net .classifier [- 1 ].parameters (), self .net .net .features [0 ].parameters ()]:
123+ params = itertools .chain (params , pram )
124+ return params
71125 else :
72- raise (ValueError ("update_layers can only be 0 (all layers) " +
73- "or -1 (the last layer)" ))
126+ raise (ValueError ("update_mode can only be 'all' or 'last'." ))
127+
128+ class MobileNetV2 (BuiltInNet ):
129+ """
130+ MobileNetV2 for classification.
131+ Parameters should be set in the `params` dictionary that contains the
132+ following fields:
74133
75- class MobileNetV2 (nn .Module ):
134+ :param input_chns: (int) Input channel number, default is 3.
135+ :param pretrain: (bool) Using pretrained model or not, default is True.
136+ :param update_mode: (str) The strategy for updating layers: "`all`" means updating
137+ all the layers, and "`last`" (by default) means updating the last layer,
138+ as well as the first layer when `input_chns` is not 3.
139+ """
76140 def __init__ (self , params ):
77141 super (MobileNetV2 , self ).__init__ ()
78- self .params = params
79- cls_num = params ['class_num' ]
80- in_chns = params .get ('input_chns' , 3 )
81- self .pretrain = params .get ('pretrain' , True )
82- self .update_layers = params .get ('update_layers' , 0 )
83142 self .net = models .mobilenet_v2 (pretrained = self .pretrain )
84143
85144 # replace the last layer
86145 num_ftrs = self .net .last_channel
87- self .net .classifier [- 1 ] = nn .Linear (num_ftrs , cls_num )
88-
89- def forward (self , x ):
90- return self .net (x )
146+ self .net .classifier [- 1 ] = nn .Linear (num_ftrs , params ['class_num' ])
147+
148+ # replace the first layer when in_chns is not 3
149+ if (self .in_chns != 3 ):
150+ self .net .features [0 ][0 ] = nn .Conv2d (self .in_chns , 32 , kernel_size = (3 , 3 ),
151+ stride = (2 , 2 ), padding = (1 , 1 ), bias = False )
91152
92153 def get_parameters_to_update (self ):
93- if (self .pretrain == False or self . update_layers == 0 ):
154+ if (self .update_mode == "all" ):
94155 return self .net .parameters ()
95- elif (self .update_layers == - 1 ):
96- return self .net .classifier [- 1 ].parameters ()
156+ elif (self .update_mode == "last" ):
157+ params = self .net .classifier [- 1 ].parameters ()
158+ if (self .in_chns != 3 ):
159+ params = itertools .chain ()
160+ for pram in [self .net .classifier [- 1 ].parameters (), self .net .net .features [0 ][0 ].parameters ()]:
161+ params = itertools .chain (params , pram )
162+ return params
97163 else :
98- raise (ValueError ("update_layers can only be 0 (all layers) " +
99- "or -1 (the last layer)" ))
164+ raise (ValueError ("update_mode can only be 'all' or 'last'." ))
165+
166+ if __name__ == "__main__" :
167+ params = {"class_num" : 2 , "pretrain" : False , "input_chns" : 3 }
168+ net = ResNet18 (params )
169+ print (net )
0 commit comments