@@ -49,14 +49,8 @@ def Conv2D(filters, seed=None, **kwargs): # pylint: disable=invalid-name
4949 return tf .keras .layers .Conv2D (filters , ** default_kwargs )
5050
5151
52- def basic_block (
53- inputs : tf .Tensor ,
54- filters : int ,
55- strides : int ,
56- conv_l2 : float ,
57- bn_l2 : float ,
58- seed : int ,
59- version : int ) -> tf .Tensor :
52+ def basic_block (inputs : tf .Tensor , filters : int , strides : int , conv_l2 : float ,
53+ bn_l2 : float , seed : int , version : int ) -> tf .Tensor :
6054 """Basic residual block of two 3x3 convs.
6155
6256 Args:
@@ -75,30 +69,42 @@ def basic_block(
7569 x = inputs
7670 y = inputs
7771 if version == 2 :
78- y = BatchNormalization (beta_regularizer = tf .keras .regularizers .l2 (bn_l2 ),
79- gamma_regularizer = tf .keras .regularizers .l2 (bn_l2 ))(y )
72+ y = BatchNormalization (
73+ beta_regularizer = tf .keras .regularizers .l2 (bn_l2 ),
74+ gamma_regularizer = tf .keras .regularizers .l2 (bn_l2 ))(
75+ y )
8076 y = tf .keras .layers .Activation ('relu' )(y )
8177 seeds = tf .random .experimental .stateless_split ([seed , seed + 1 ], 3 )[:, 0 ]
82- y = Conv2D (filters ,
83- strides = strides ,
84- seed = seeds [0 ],
85- kernel_regularizer = tf .keras .regularizers .l2 (conv_l2 ))(y )
86- y = BatchNormalization (beta_regularizer = tf .keras .regularizers .l2 (bn_l2 ),
87- gamma_regularizer = tf .keras .regularizers .l2 (bn_l2 ))(y )
78+ y = Conv2D (
79+ filters ,
80+ strides = strides ,
81+ seed = seeds [0 ],
82+ kernel_regularizer = tf .keras .regularizers .l2 (conv_l2 ))(
83+ y )
84+ y = BatchNormalization (
85+ beta_regularizer = tf .keras .regularizers .l2 (bn_l2 ),
86+ gamma_regularizer = tf .keras .regularizers .l2 (bn_l2 ))(
87+ y )
8888 y = tf .keras .layers .Activation ('relu' )(y )
89- y = Conv2D (filters ,
90- strides = 1 ,
91- seed = seeds [1 ],
92- kernel_regularizer = tf .keras .regularizers .l2 (conv_l2 ))(y )
89+ y = Conv2D (
90+ filters ,
91+ strides = 1 ,
92+ seed = seeds [1 ],
93+ kernel_regularizer = tf .keras .regularizers .l2 (conv_l2 ))(
94+ y )
9395 if version == 1 :
94- y = BatchNormalization (beta_regularizer = tf .keras .regularizers .l2 (bn_l2 ),
95- gamma_regularizer = tf .keras .regularizers .l2 (bn_l2 ))(y )
96+ y = BatchNormalization (
97+ beta_regularizer = tf .keras .regularizers .l2 (bn_l2 ),
98+ gamma_regularizer = tf .keras .regularizers .l2 (bn_l2 ))(
99+ y )
96100 if not x .shape .is_compatible_with (y .shape ):
97- x = Conv2D (filters ,
98- kernel_size = 1 ,
99- strides = strides ,
100- seed = seeds [2 ],
101- kernel_regularizer = tf .keras .regularizers .l2 (conv_l2 ))(x )
101+ x = Conv2D (
102+ filters ,
103+ kernel_size = 1 ,
104+ strides = strides ,
105+ seed = seeds [2 ],
106+ kernel_regularizer = tf .keras .regularizers .l2 (conv_l2 ))(
107+ x )
102108 x = tf .keras .layers .add ([x , y ])
103109 if version == 1 :
104110 x = tf .keras .layers .Activation ('relu' )(x )
@@ -107,8 +113,8 @@ def basic_block(
107113
108114def group (inputs , filters , strides , num_blocks , conv_l2 , bn_l2 , version , seed ):
109115 """Group of residual blocks."""
110- seeds = tf .random .experimental .stateless_split (
111- [ seed , seed + 1 ], num_blocks )[:, 0 ]
116+ seeds = tf .random .experimental .stateless_split ([ seed , seed + 1 ],
117+ num_blocks )[:, 0 ]
112118 x = basic_block (
113119 inputs ,
114120 filters = filters ,
@@ -187,49 +193,59 @@ def wide_resnet(
187193 raise ValueError ('depth should be 6n+4 (e.g., 16, 22, 28, 40).' )
188194 num_blocks = (depth - 4 ) // 6
189195 inputs = tf .keras .layers .Input (shape = input_shape )
190- x = Conv2D (16 ,
191- strides = 1 ,
192- seed = seeds [0 ],
193- kernel_regularizer = l2_reg (hps ['input_conv_l2' ]))(inputs )
196+ x = Conv2D (
197+ 16 ,
198+ strides = 1 ,
199+ seed = seeds [0 ],
200+ kernel_regularizer = l2_reg (hps ['input_conv_l2' ]))(
201+ inputs )
194202 if version == 1 :
195- x = BatchNormalization (beta_regularizer = l2_reg (hps ['bn_l2' ]),
196- gamma_regularizer = l2_reg (hps ['bn_l2' ]))(x )
203+ x = BatchNormalization (
204+ beta_regularizer = l2_reg (hps ['bn_l2' ]),
205+ gamma_regularizer = l2_reg (hps ['bn_l2' ]))(
206+ x )
197207 x = tf .keras .layers .Activation ('relu' )(x )
198- x = group (x ,
199- filters = 16 * width_multiplier ,
200- strides = 1 ,
201- num_blocks = num_blocks ,
202- conv_l2 = hps ['group_1_conv_l2' ],
203- bn_l2 = hps ['bn_l2' ],
204- version = version ,
205- seed = seeds [1 ])
206- x = group (x ,
207- filters = 32 * width_multiplier ,
208- strides = 2 ,
209- num_blocks = num_blocks ,
210- conv_l2 = hps ['group_2_conv_l2' ],
211- bn_l2 = hps ['bn_l2' ],
212- version = version ,
213- seed = seeds [2 ])
214- x = group (x ,
215- filters = 64 * width_multiplier ,
216- strides = 2 ,
217- num_blocks = num_blocks ,
218- conv_l2 = hps ['group_3_conv_l2' ],
219- bn_l2 = hps ['bn_l2' ],
220- version = version ,
221- seed = seeds [3 ])
208+ x = group (
209+ x ,
210+ filters = 16 * width_multiplier ,
211+ strides = 1 ,
212+ num_blocks = num_blocks ,
213+ conv_l2 = hps ['group_1_conv_l2' ],
214+ bn_l2 = hps ['bn_l2' ],
215+ version = version ,
216+ seed = seeds [1 ])
217+ x = group (
218+ x ,
219+ filters = 32 * width_multiplier ,
220+ strides = 2 ,
221+ num_blocks = num_blocks ,
222+ conv_l2 = hps ['group_2_conv_l2' ],
223+ bn_l2 = hps ['bn_l2' ],
224+ version = version ,
225+ seed = seeds [2 ])
226+ x = group (
227+ x ,
228+ filters = 64 * width_multiplier ,
229+ strides = 2 ,
230+ num_blocks = num_blocks ,
231+ conv_l2 = hps ['group_3_conv_l2' ],
232+ bn_l2 = hps ['bn_l2' ],
233+ version = version ,
234+ seed = seeds [3 ])
222235 if version == 2 :
223- x = BatchNormalization (beta_regularizer = l2_reg (hps ['bn_l2' ]),
224- gamma_regularizer = l2_reg (hps ['bn_l2' ]))(x )
236+ x = BatchNormalization (
237+ beta_regularizer = l2_reg (hps ['bn_l2' ]),
238+ gamma_regularizer = l2_reg (hps ['bn_l2' ]))(
239+ x )
225240 x = tf .keras .layers .Activation ('relu' )(x )
226241 x = tf .keras .layers .AveragePooling2D (pool_size = 8 )(x )
227242 x = tf .keras .layers .Flatten ()(x )
228243 x = tf .keras .layers .Dense (
229244 num_classes ,
230245 kernel_initializer = tf .keras .initializers .HeNormal (seed = seeds [4 ]),
231246 kernel_regularizer = l2_reg (hps ['dense_kernel_l2' ]),
232- bias_regularizer = l2_reg (hps ['dense_bias_l2' ]))(x )
247+ bias_regularizer = l2_reg (hps ['dense_bias_l2' ]))(
248+ x )
233249 return tf .keras .Model (
234250 inputs = inputs ,
235251 outputs = x ,
0 commit comments