1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from tensorflow import keras
1516from tensorflow .keras import initializers
1617from tensorflow .keras import layers
18+ from tensorflow .keras import regularizers
1719
1820
1921def Block (filters , downsample , sync_bn ):
2022 """A default block which serves as an example of the block interface.
2123
2224 This is the base block definition for a CenterPillar model.
25+
26+ Note that the sync_bn parameter is a temporary workaround and should _not_
27+ be part of the Block API.
2328 """
2429
2530 def apply (x ):
@@ -33,7 +38,9 @@ def apply(x):
3338 3 ,
3439 stride ,
3540 padding = "same" ,
41+ use_bias = False ,
3642 kernel_initializer = initializers .VarianceScaling (),
43+ kernel_regularizer = regularizers .L2 (l2 = 1e-4 ),
3744 )(x )
3845 if sync_bn :
3946 x = layers .BatchNormalization (
@@ -48,14 +55,17 @@ def apply(x):
4855 3 ,
4956 1 ,
5057 padding = "same" ,
58+ use_bias = False ,
5159 kernel_initializer = initializers .VarianceScaling (),
60+ kernel_regularizer = regularizers .L2 (l2 = 1e-4 ),
5261 )(x )
5362 if sync_bn :
5463 x = layers .BatchNormalization (
5564 synchronized = True ,
5665 )(x )
5766 else :
5867 x = layers .BatchNormalization ()(x )
68+ x = layers .ReLU ()(x )
5969
6070 if downsample :
6171 residual = layers .MaxPool2D (pool_size = 2 , strides = 2 , padding = "SAME" )(
@@ -68,11 +78,19 @@ def apply(x):
6878 1 ,
6979 1 ,
7080 padding = "same" ,
81+ use_bias = False ,
7182 kernel_initializer = initializers .VarianceScaling (),
83+ kernel_regularizer = regularizers .L2 (l2 = 1e-4 ),
7284 )(residual )
85+ if sync_bn :
86+ residual = layers .BatchNormalization (
87+ synchronized = True ,
88+ )(residual )
89+ else :
90+ residual = layers .BatchNormalization ()(residual )
91+ residual = layers .ReLU ()(residual )
7392
7493 x = x + residual
75- x = layers .ReLU ()(x )
7694
7795 return x
7896
@@ -85,7 +103,9 @@ def apply(x):
85103 filters ,
86104 1 ,
87105 1 ,
106+ use_bias = False ,
88107 kernel_initializer = initializers .VarianceScaling (),
108+ kernel_regularizer = regularizers .L2 (l2 = 1e-4 ),
89109 )(x )
90110 if sync_bn :
91111 x = layers .BatchNormalization (
@@ -119,7 +139,9 @@ def apply(x, lateral_input):
119139 3 ,
120140 2 ,
121141 padding = "same" ,
142+ use_bias = False ,
122143 kernel_initializer = initializers .VarianceScaling (),
144+ kernel_regularizer = regularizers .L2 (l2 = 1e-4 ),
123145 )(x )
124146 if sync_bn :
125147 x = layers .BatchNormalization (
@@ -129,7 +151,7 @@ def apply(x, lateral_input):
129151 x = layers .BatchNormalization ()(x )
130152 x = layers .ReLU ()(x )
131153
132- lateral_input = SkipBlock (filters , sync_bn )(lateral_input )
154+ lateral_input = SkipBlock (filters , sync_bn = sync_bn )(lateral_input )
133155
134156 x = x + lateral_input
135157 x = Block (filters , downsample = False , sync_bn = sync_bn )(x )
@@ -140,6 +162,7 @@ def apply(x, lateral_input):
140162
141163
142164def UNet (
165+ input_shape ,
143166 down_block_configs ,
144167 up_block_configs ,
145168 down_block = DownSampleBlock ,
@@ -155,26 +178,28 @@ def UNet(
155178 function that acts on tensors as inputs.
156179
157180 Args:
181+ input_shape: the rank 3 shape of the input to the UNet
158182 down_block_configs: a list of (filter_count, num_blocks) tuples indicating the
159183 number of filters and sub-blocks in each down block
160184 up_block_configs: a list of filter counts, one for each up block
161185 down_block: a downsampling block
162186 up_block: an upsampling block
163- sync_bn: True for synchronized batch norm.
164187 """
165188
166- def apply (x ):
167- skip_connections = []
168- # Filters refers to the number of convolutional filters in each block,
169- # while num_blocks refers to the number of sub-blocks within a block
170- # (Note that only the first sub-block will perform downsampling)
171- for filters , num_blocks in down_block_configs :
172- skip_connections .append (x )
173- x = down_block (filters , num_blocks , sync_bn )(x )
189+ input = layers .Input (shape = input_shape )
190+ x = input
174191
175- for filters in up_block_configs :
176- x = up_block (filters , sync_bn )(x , skip_connections .pop ())
192+ skip_connections = []
193+ # Filters refers to the number of convolutional filters in each block,
194+ # while num_blocks refers to the number of sub-blocks within a block
195+ # (Note that only the first sub-block will perform downsampling)
196+ for filters , num_blocks in down_block_configs :
197+ skip_connections .append (x )
198+ x = down_block (filters , num_blocks , sync_bn = sync_bn )(x )
177199
178- return x
200+ for filters in up_block_configs :
201+ x = up_block (filters , sync_bn = sync_bn )(x , skip_connections .pop ())
179202
180- return apply
203+ output = x
204+
205+ return keras .Model (input , output )
0 commit comments