@@ -35,7 +35,13 @@ def make_divisible(v, divisor=8, min_value=None):
35
35
36
36
class MobileNetV3 (nn .Layer ):
37
37
def __init__ (
38
- self , in_channels = 3 , model_name = "large" , scale = 0.5 , disable_se = False , ** kwargs
38
+ self ,
39
+ in_channels = 3 ,
40
+ model_name = "large" ,
41
+ scale = 0.5 ,
42
+ disable_se = False ,
43
+ data_format = "NCHW" ,
44
+ ** kwargs ,
39
45
):
40
46
"""
41
47
the MobilenetV3 backbone network for detection module.
@@ -46,6 +52,7 @@ def __init__(
46
52
47
53
self .disable_se = disable_se
48
54
55
+ self .nchw = data_format == "NCHW"
49
56
if model_name == "large" :
50
57
cfg = [
51
58
# k, exp, c, se, nl, s,
@@ -102,6 +109,7 @@ def __init__(
102
109
groups = 1 ,
103
110
if_act = True ,
104
111
act = "hardswish" ,
112
+ data_format = data_format ,
105
113
)
106
114
107
115
self .stages = []
@@ -125,6 +133,7 @@ def __init__(
125
133
stride = s ,
126
134
use_se = se ,
127
135
act = nl ,
136
+ data_format = data_format ,
128
137
)
129
138
)
130
139
inplanes = make_divisible (scale * c )
@@ -139,6 +148,7 @@ def __init__(
139
148
groups = 1 ,
140
149
if_act = True ,
141
150
act = "hardswish" ,
151
+ data_format = data_format ,
142
152
)
143
153
)
144
154
self .stages .append (nn .Sequential (* block_list ))
@@ -147,6 +157,8 @@ def __init__(
147
157
self .add_sublayer (sublayer = stage , name = "stage{}" .format (i ))
148
158
149
159
def forward (self , x ):
160
+ if not self .nchw :
161
+ x = x .transpose ([0 , 2 , 3 , 1 ])
150
162
x = self .conv (x )
151
163
out_list = []
152
164
for stage in self .stages :
@@ -166,6 +178,7 @@ def __init__(
166
178
groups = 1 ,
167
179
if_act = True ,
168
180
act = None ,
181
+ data_format = "NCHW" ,
169
182
):
170
183
super (ConvBNLayer , self ).__init__ ()
171
184
self .if_act = if_act
@@ -178,9 +191,12 @@ def __init__(
178
191
padding = padding ,
179
192
groups = groups ,
180
193
bias_attr = False ,
194
+ data_format = data_format ,
181
195
)
182
196
183
- self .bn = nn .BatchNorm (num_channels = out_channels , act = None )
197
+ self .bn = nn .BatchNorm (
198
+ num_channels = out_channels , act = None , data_layout = data_format
199
+ )
184
200
185
201
def forward (self , x ):
186
202
x = self .conv (x )
@@ -210,6 +226,7 @@ def __init__(
210
226
stride ,
211
227
use_se ,
212
228
act = None ,
229
+ data_format = "NCHW" ,
213
230
):
214
231
super (ResidualUnit , self ).__init__ ()
215
232
self .if_shortcut = stride == 1 and in_channels == out_channels
@@ -223,6 +240,7 @@ def __init__(
223
240
padding = 0 ,
224
241
if_act = True ,
225
242
act = act ,
243
+ data_format = data_format ,
226
244
)
227
245
self .bottleneck_conv = ConvBNLayer (
228
246
in_channels = mid_channels ,
@@ -233,9 +251,10 @@ def __init__(
233
251
groups = mid_channels ,
234
252
if_act = True ,
235
253
act = act ,
254
+ data_format = data_format ,
236
255
)
237
256
if self .if_se :
238
- self .mid_se = SEModule (mid_channels )
257
+ self .mid_se = SEModule (mid_channels , data_format = data_format )
239
258
self .linear_conv = ConvBNLayer (
240
259
in_channels = mid_channels ,
241
260
out_channels = out_channels ,
@@ -244,6 +263,7 @@ def __init__(
244
263
padding = 0 ,
245
264
if_act = False ,
246
265
act = None ,
266
+ data_format = data_format ,
247
267
)
248
268
249
269
def forward (self , inputs ):
@@ -258,22 +278,24 @@ def forward(self, inputs):
258
278
259
279
260
280
class SEModule (nn .Layer ):
261
- def __init__ (self , in_channels , reduction = 4 ):
281
+ def __init__ (self , in_channels , reduction = 4 , data_format = "NCHW" ):
262
282
super (SEModule , self ).__init__ ()
263
- self .avg_pool = nn .AdaptiveAvgPool2D (1 )
283
+ self .avg_pool = nn .AdaptiveAvgPool2D (1 , data_format = data_format )
264
284
self .conv1 = nn .Conv2D (
265
285
in_channels = in_channels ,
266
286
out_channels = in_channels // reduction ,
267
287
kernel_size = 1 ,
268
288
stride = 1 ,
269
289
padding = 0 ,
290
+ data_format = data_format ,
270
291
)
271
292
self .conv2 = nn .Conv2D (
272
293
in_channels = in_channels // reduction ,
273
294
out_channels = in_channels ,
274
295
kernel_size = 1 ,
275
296
stride = 1 ,
276
297
padding = 0 ,
298
+ data_format = data_format ,
277
299
)
278
300
279
301
def forward (self , inputs ):
0 commit comments