Skip to content

Commit c61cb24

Browse files
authored
Update the internal UNet API to return a Keras model (#1286)
- Update the internal UNet API to return a Keras model
- Fix centerpillar test
1 parent 80dff46 commit c61cb24

File tree

3 files changed

+63
-22
lines changed

3 files changed

+63
-22
lines changed

keras_cv/models/__internal__/unet.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from tensorflow import keras
1516
from tensorflow.keras import initializers
1617
from tensorflow.keras import layers
18+
from tensorflow.keras import regularizers
1719

1820

1921
def Block(filters, downsample, sync_bn):
2022
"""A default block which serves as an example of the block interface.
2123
2224
This is the base block definition for a CenterPillar model.
25+
26+
Note that the sync_bn parameter is a temporary workaround and should _not_
27+
be part of the Block API.
2328
"""
2429

2530
def apply(x):
@@ -33,7 +38,9 @@ def apply(x):
3338
3,
3439
stride,
3540
padding="same",
41+
use_bias=False,
3642
kernel_initializer=initializers.VarianceScaling(),
43+
kernel_regularizer=regularizers.L2(l2=1e-4),
3744
)(x)
3845
if sync_bn:
3946
x = layers.BatchNormalization(
@@ -48,14 +55,17 @@ def apply(x):
4855
3,
4956
1,
5057
padding="same",
58+
use_bias=False,
5159
kernel_initializer=initializers.VarianceScaling(),
60+
kernel_regularizer=regularizers.L2(l2=1e-4),
5261
)(x)
5362
if sync_bn:
5463
x = layers.BatchNormalization(
5564
synchronized=True,
5665
)(x)
5766
else:
5867
x = layers.BatchNormalization()(x)
68+
x = layers.ReLU()(x)
5969

6070
if downsample:
6171
residual = layers.MaxPool2D(pool_size=2, strides=2, padding="SAME")(
@@ -68,11 +78,19 @@ def apply(x):
6878
1,
6979
1,
7080
padding="same",
81+
use_bias=False,
7182
kernel_initializer=initializers.VarianceScaling(),
83+
kernel_regularizer=regularizers.L2(l2=1e-4),
7284
)(residual)
85+
if sync_bn:
86+
residual = layers.BatchNormalization(
87+
synchronized=True,
88+
)(residual)
89+
else:
90+
residual = layers.BatchNormalization()(residual)
91+
residual = layers.ReLU()(residual)
7392

7493
x = x + residual
75-
x = layers.ReLU()(x)
7694

7795
return x
7896

@@ -85,7 +103,9 @@ def apply(x):
85103
filters,
86104
1,
87105
1,
106+
use_bias=False,
88107
kernel_initializer=initializers.VarianceScaling(),
108+
kernel_regularizer=regularizers.L2(l2=1e-4),
89109
)(x)
90110
if sync_bn:
91111
x = layers.BatchNormalization(
@@ -119,7 +139,9 @@ def apply(x, lateral_input):
119139
3,
120140
2,
121141
padding="same",
142+
use_bias=False,
122143
kernel_initializer=initializers.VarianceScaling(),
144+
kernel_regularizer=regularizers.L2(l2=1e-4),
123145
)(x)
124146
if sync_bn:
125147
x = layers.BatchNormalization(
@@ -129,7 +151,7 @@ def apply(x, lateral_input):
129151
x = layers.BatchNormalization()(x)
130152
x = layers.ReLU()(x)
131153

132-
lateral_input = SkipBlock(filters, sync_bn)(lateral_input)
154+
lateral_input = SkipBlock(filters, sync_bn=sync_bn)(lateral_input)
133155

134156
x = x + lateral_input
135157
x = Block(filters, downsample=False, sync_bn=sync_bn)(x)
@@ -140,6 +162,7 @@ def apply(x, lateral_input):
140162

141163

142164
def UNet(
165+
input_shape,
143166
down_block_configs,
144167
up_block_configs,
145168
down_block=DownSampleBlock,
@@ -155,26 +178,28 @@ def UNet(
155178
function that acts on tensors as inputs.
156179
157180
Args:
181+
input_shape: the rank 3 shape of the input to the UNet
158182
down_block_configs: a list of (filter_count, num_blocks) tuples indicating the
159183
number of filters and sub-blocks in each down block
160184
up_block_configs: a list of filter counts, one for each up block
161185
down_block: a downsampling block
162186
up_block: an upsampling block
163-
sync_bn: True for synchronized batch norm.
164187
"""
165188

166-
def apply(x):
167-
skip_connections = []
168-
# Filters refers to the number of convolutional filters in each block,
169-
# while num_blocks refers to the number of sub-blocks within a block
170-
# (Note that only the first sub-block will perform downsampling)
171-
for filters, num_blocks in down_block_configs:
172-
skip_connections.append(x)
173-
x = down_block(filters, num_blocks, sync_bn)(x)
189+
input = layers.Input(shape=input_shape)
190+
x = input
174191

175-
for filters in up_block_configs:
176-
x = up_block(filters, sync_bn)(x, skip_connections.pop())
192+
skip_connections = []
193+
# Filters refers to the number of convolutional filters in each block,
194+
# while num_blocks refers to the number of sub-blocks within a block
195+
# (Note that only the first sub-block will perform downsampling)
196+
for filters, num_blocks in down_block_configs:
197+
skip_connections.append(x)
198+
x = down_block(filters, num_blocks, sync_bn=sync_bn)(x)
177199

178-
return x
200+
for filters in up_block_configs:
201+
x = up_block(filters, sync_bn=sync_bn)(x, skip_connections.pop())
179202

180-
return apply
203+
output = x
204+
205+
return keras.Model(input, output)

keras_cv/models/__internal__/unet_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,28 @@
1818

1919

2020
class UNetTest(tf.test.TestCase):
21-
# This test is disabled because it requires tf-nightly to run
22-
# (tf-nightly includes the synchronized param for BatchNorm layer)
23-
def test_example_unet_output_shape_bn(self):
21+
def test_example_unet_sync_bn_false(self):
2422
x = tf.random.normal((1, 16, 16, 5))
25-
output = UNet([(128, 6), (256, 2), (512, 1)], [512, 256, 256], sync_bn=False)(x)
23+
model = UNet(
24+
input_shape=(16, 16, 5),
25+
down_block_configs=[(128, 6), (256, 2), (512, 1)],
26+
up_block_configs=[512, 256, 256],
27+
sync_bn=False,
28+
)
29+
output = model(x)
2630
self.assertEqual(output.shape, x.shape[:-1] + (256))
31+
self.assertLen(model.layers, 118)
2732

28-
def disable_test_example_unet_output_shape_sync_bn(self):
33+
# This test is disabled because it requires tf-nightly to run
34+
# (tf-nightly includes the synchronized param for BatchNorm layer)
35+
def disable_test_example_unet_sync_bn_true(self):
2936
x = tf.random.normal((1, 16, 16, 5))
30-
output = UNet([(128, 6), (256, 2), (512, 1)], [512, 256, 256], sync_bn=True)(x)
37+
model = UNet(
38+
input_shape=(16, 16, 5),
39+
down_block_configs=[(128, 6), (256, 2), (512, 1)],
40+
up_block_configs=[512, 256, 256],
41+
sync_bn=True,
42+
)
43+
output = model(x)
3144
self.assertEqual(output.shape, x.shape[:-1] + (256))
45+
self.assertLen(model.layers, 118)

keras_cv/models/object_detection3d/center_pillar_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def build_centerpillar_unet(self, input_shape):
5050
)(x)
5151
x = tf.keras.layers.ReLU()(x)
5252
x = Block(128, downsample=False, sync_bn=False)(x)
53-
output = UNet(down_block_configs, up_block_configs, sync_bn=False)(x)
53+
output = UNet(x.shape[1:], down_block_configs, up_block_configs, sync_bn=False)(
54+
x
55+
)
5456
return tf.keras.Model(input, output)
5557

5658
def test_center_pillar_call(self):

0 commit comments

Comments (0)