Skip to content

Commit 3e7d7b7

Browse files
authored
keras refact (#323)
* keras refact * fix test cases * remove useless config * more kerasify * simplify code * move to keras directory * add graph test * keras_impl to keras * format code
1 parent 5f3f351 commit 3e7d7b7

File tree

2 files changed

+268
-101
lines changed

2 files changed

+268
-101
lines changed
Lines changed: 201 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,206 @@
1-
"""A keras implementation of the efficientdet architecture."""
2-
1+
import functools
32
import tensorflow.compat.v1 as tf
4-
5-
import efficientdet_arch
3+
from tensorflow.python.keras.utils import conv_utils
4+
from efficientdet_arch import nearest_upsampling, build_bifpn_layer
65
import utils
76

87

98
class BiFPNLayer(tf.keras.layers.Layer):
10-
"""A Keras Layer implementing Bidirectional Feature Pyramids"""
11-
12-
def __init__(self,
13-
min_level: int,
14-
max_level: int,
15-
image_size: int,
16-
fpn_weight_method: str,
17-
apply_bn_for_resampling: bool,
18-
is_training_bn: bool,
19-
conv_after_downsample: bool,
20-
use_native_resize_op: bool,
21-
data_format: str,
22-
pooling_type: str,
23-
fpn_num_filters: int,
24-
conv_bn_act_pattern: bool,
25-
act_type: str,
26-
separable_conv: bool,
27-
use_tpu: bool,
28-
fpn_name: str,
29-
**kwargs):
30-
31-
self.min_level = min_level
32-
self.max_level = max_level
33-
self.image_size = image_size
34-
self.feat_sizes = utils.get_feat_sizes(image_size, max_level)
35-
36-
self.fpn_weight_method = fpn_weight_method
37-
self.apply_bn_for_resampling = apply_bn_for_resampling
38-
self.is_training_bn = is_training_bn
39-
self.conv_after_downsample = conv_after_downsample
40-
self.use_native_resize_op = use_native_resize_op
41-
self.data_format = data_format
42-
self.fpn_num_filters = fpn_num_filters
43-
self.pooling_type = pooling_type
44-
self.conv_bn_act_pattern = conv_bn_act_pattern
45-
self.act_type = act_type
46-
self.use_tpu = use_tpu
47-
self.separable_conv = separable_conv
48-
49-
self.fpn_config = None
50-
self.fpn_name = fpn_name
51-
52-
super(BiFPNLayer, self).__init__(**kwargs)
53-
54-
def call(self, feats):
55-
# @TODO: Implement this with keras logic
56-
return efficientdet_arch.build_bifpn_layer(feats, self.feat_sizes, self)
57-
58-
def get_config(self):
59-
base_config = super(BiFPNLayer, self).get_config()
60-
61-
return {
62-
**base_config,
63-
"min_level": self.min_level,
64-
"max_level": self.max_level,
65-
"image_size": self.image_size,
66-
"fpn_name": self.fpn_name,
67-
"fpn_weight_method": self.fpn_weight_method,
68-
"apply_bn_for_resampling": self.apply_bn_for_resampling,
69-
"is_training_bn": self.is_training_bn,
70-
"conv_after_downsample": self.conv_after_downsample,
71-
"use_native_resize_op": self.use_native_resize_op,
72-
"data_format": self.data_format,
73-
"pooling_type": self.pooling_type,
74-
"fpn_num_filters": self.fpn_num_filters,
75-
"conv_bn_act_pattern": self.conv_bn_act_pattern,
76-
"act_type": self.act_type,
77-
"separable_conv": self.separable_conv,
78-
"use_tpu": self.use_tpu,
79-
}
9+
"""A Keras Layer implementing Bidirectional Feature Pyramids"""
10+
11+
def __init__(self,
12+
min_level: int,
13+
max_level: int,
14+
image_size: int,
15+
fpn_weight_method: str,
16+
apply_bn_for_resampling: bool,
17+
is_training_bn: bool,
18+
conv_after_downsample: bool,
19+
use_native_resize_op: bool,
20+
data_format: str,
21+
pooling_type: str,
22+
fpn_num_filters: int,
23+
conv_bn_act_pattern: bool,
24+
act_type: str,
25+
separable_conv: bool,
26+
use_tpu: bool,
27+
fpn_name: str,
28+
**kwargs):
29+
self.min_level = min_level
30+
self.max_level = max_level
31+
self.image_size = image_size
32+
self.feat_sizes = utils.get_feat_sizes(image_size, max_level)
33+
34+
self.fpn_weight_method = fpn_weight_method
35+
self.apply_bn_for_resampling = apply_bn_for_resampling
36+
self.is_training_bn = is_training_bn
37+
self.conv_after_downsample = conv_after_downsample
38+
self.use_native_resize_op = use_native_resize_op
39+
self.data_format = data_format
40+
self.fpn_num_filters = fpn_num_filters
41+
self.pooling_type = pooling_type
42+
self.conv_bn_act_pattern = conv_bn_act_pattern
43+
self.act_type = act_type
44+
self.use_tpu = use_tpu
45+
self.separable_conv = separable_conv
46+
47+
self.fpn_config = None
48+
self.fpn_name = fpn_name
49+
50+
super(BiFPNLayer, self).__init__(**kwargs)
51+
52+
def call(self, feats):
53+
# @TODO: Implement this with keras logic
54+
return build_bifpn_layer(feats, self.feat_sizes, self)
55+
56+
def get_config(self):
57+
base_config = super(BiFPNLayer, self).get_config()
58+
59+
return {
60+
**base_config,
61+
"min_level": self.min_level,
62+
"max_level": self.max_level,
63+
"image_size": self.image_size,
64+
"fpn_name": self.fpn_name,
65+
"fpn_weight_method": self.fpn_weight_method,
66+
"apply_bn_for_resampling": self.apply_bn_for_resampling,
67+
"is_training_bn": self.is_training_bn,
68+
"conv_after_downsample": self.conv_after_downsample,
69+
"use_native_resize_op": self.use_native_resize_op,
70+
"data_format": self.data_format,
71+
"pooling_type": self.pooling_type,
72+
"fpn_num_filters": self.fpn_num_filters,
73+
"conv_bn_act_pattern": self.conv_bn_act_pattern,
74+
"act_type": self.act_type,
75+
"separable_conv": self.separable_conv,
76+
"use_tpu": self.use_tpu,
77+
}
78+
79+
class ResampleFeatureMap(tf.keras.layers.Layer):
80+
def __init__(self,
81+
target_height,
82+
target_width,
83+
target_num_channels,
84+
apply_bn=False,
85+
is_training=None,
86+
conv_after_downsample=False,
87+
use_native_resize_op=False,
88+
pooling_type=None,
89+
use_tpu=False,
90+
data_format=None,
91+
name='resample_feature_map'):
92+
super(ResampleFeatureMap, self).__init__(name='resample_{}'.format(name))
93+
self.apply_bn = apply_bn
94+
self.is_training = is_training
95+
self.data_format = conv_utils.normalize_data_format(data_format)
96+
self.target_num_channels = target_num_channels
97+
self.target_height = target_height
98+
self.target_width = target_width
99+
self.use_tpu = use_tpu
100+
self.conv_after_downsample = conv_after_downsample
101+
self.use_native_resize_op = use_native_resize_op
102+
self.pooling_type = pooling_type
103+
self.conv2d = tf.keras.layers.Conv2D(
104+
self.target_num_channels,
105+
(1, 1),
106+
padding='same',
107+
data_format=self.data_format)
108+
109+
def build(self, input_shape):
110+
"""Resample input feature map to have target number of channels and size."""
111+
if self.data_format == 'channels_first':
112+
_, num_channels, height, width = input_shape.as_list()
113+
else:
114+
_, height, width, num_channels = input_shape.as_list()
115+
116+
if height is None or width is None or num_channels is None:
117+
raise ValueError(
118+
'shape[1] or shape[2] or shape[3] of feat is None (shape:{}).'.format(
119+
input_shape.as_list()))
120+
if self.apply_bn and self.is_training is None:
121+
raise ValueError('If BN is applied, need to provide is_training')
122+
self.num_channels = num_channels
123+
self.height = height
124+
self.width = width
125+
height_stride_size = int((self.height - 1) // self.target_height + 1)
126+
width_stride_size = int((self.width - 1) // self.target_width + 1)
127+
128+
if self.pooling_type == 'max' or self.pooling_type is None:
129+
# Use max pooling in default.
130+
self.pool2d = tf.keras.layers.MaxPooling2D(
131+
pool_size=[height_stride_size + 1, width_stride_size + 1],
132+
strides=[height_stride_size, width_stride_size],
133+
padding='SAME',
134+
data_format=self.data_format)
135+
elif self.pooling_type == 'avg':
136+
self.pool2d = tf.keras.layers.AveragePooling2D(
137+
pool_size=[height_stride_size + 1, width_stride_size + 1],
138+
strides=[height_stride_size, width_stride_size],
139+
padding='SAME',
140+
data_format=self.data_format)
141+
else:
142+
raise ValueError('Unknown pooling type: {}'.format(self.pooling_type))
143+
144+
height_scale = self.target_height // self.height
145+
width_scale = self.target_width // self.width
146+
if (self.use_native_resize_op or self.target_height % self.height != 0 or
147+
self.target_width % self.width != 0):
148+
self.upsample2d = tf.keras.layers.UpSampling2D(
149+
(height_scale, width_scale),
150+
data_format=self.data_format)
151+
else:
152+
self.upsample2d = functools.partial(nearest_upsampling,
153+
height_scale=height_scale,
154+
width_scale=width_scale,
155+
data_format=self.data_format)
156+
super(ResampleFeatureMap, self).build(input_shape)
157+
158+
def _maybe_apply_1x1(self, feat):
159+
"""Apply 1x1 conv to change layer width if necessary."""
160+
if self.num_channels != self.target_num_channels:
161+
feat = self.conv2d(feat)
162+
if self.apply_bn:
163+
feat = utils.batch_norm_act(
164+
feat,
165+
is_training_bn=self.is_training,
166+
act_type=None,
167+
data_format=self.data_format,
168+
use_tpu=self.use_tpu,
169+
name='bn')
170+
return feat
171+
172+
def call(self, feat):
173+
# If conv_after_downsample is True, when downsampling, apply 1x1 after
174+
# downsampling for efficiency.
175+
if self.height > self.target_height and self.width > self.target_width:
176+
if not self.conv_after_downsample:
177+
feat = self._maybe_apply_1x1(feat)
178+
feat = self.pool2d(feat)
179+
if self.conv_after_downsample:
180+
feat = self._maybe_apply_1x1(feat)
181+
elif self.height <= self.target_height and self.width <= self.target_width:
182+
feat = self._maybe_apply_1x1(feat)
183+
if self.height < self.target_height or self.width < self.target_width:
184+
feat = self.upsample2d(feat)
185+
else:
186+
raise ValueError(
187+
'Incompatible target feature map size: target_height: {},'
188+
'target_width: {}'.format(self.target_height, self.target_width))
189+
190+
return feat
191+
192+
def get_config(self):
193+
config = {
194+
'apply_bn': self.apply_bn,
195+
'is_training': self.is_training,
196+
'data_format': self.data_format,
197+
'target_num_channels': self.target_num_channels,
198+
'target_height': self.target_height,
199+
'target_width': self.target_width,
200+
'use_tpu': self.use_tpu,
201+
'conv_after_downsample': self.conv_after_downsample,
202+
'use_native_resize_op': self.use_native_resize_op,
203+
'pooling_type': self.pooling_type,
204+
}
205+
base_config = super(ResampleFeatureMap, self).get_config()
206+
return dict(list(base_config.items()) + list(config.items()))
Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,76 @@
11
import tensorflow.compat.v1 as tf
2-
3-
import keras.efficientdet_arch_keras as arch_keras
2+
from tensorflow.python.framework.test_util import deprecated_graph_mode_only
3+
import efficientdet_arch
4+
from keras import efficientdet_arch_keras
45
import hparams_config
56

6-
77
class KerasBiFPNTest(tf.test.TestCase):
88

9-
def test_BiFPNLayer_get_config(self):
10-
config = hparams_config.get_efficientdet_config()
11-
keras_bifpn = arch_keras.BiFPNLayer(
12-
fpn_name=config.fpn_name,
13-
min_level=config.min_level,
14-
max_level=config.max_level,
15-
fpn_weight_method=config.fpn_weight_method,
16-
apply_bn_for_resampling=config.apply_bn_for_resampling,
17-
is_training_bn=config.is_training_bn,
18-
conv_after_downsample=config.conv_after_downsample,
19-
use_native_resize_op=config.use_native_resize_op,
20-
data_format=config.data_format,
21-
image_size=config.image_size,
22-
fpn_num_filters=config.fpn_num_filters,
23-
conv_bn_act_pattern=config.conv_bn_act_pattern,
24-
act_type=config.act_type,
25-
pooling_type=config.pooling_type,
26-
separable_conv=config.separable_conv,
27-
use_tpu=config.use_tpu
28-
)
9+
def test_BiFPNLayer_get_config(self):
10+
config = hparams_config.get_efficientdet_config()
11+
keras_bifpn = efficientdet_arch_keras.BiFPNLayer(
12+
fpn_name=config.fpn_name,
13+
min_level=config.min_level,
14+
max_level=config.max_level,
15+
fpn_weight_method=config.fpn_weight_method,
16+
apply_bn_for_resampling=config.apply_bn_for_resampling,
17+
is_training_bn=config.is_training_bn,
18+
conv_after_downsample=config.conv_after_downsample,
19+
use_native_resize_op=config.use_native_resize_op,
20+
data_format=config.data_format,
21+
image_size=config.image_size,
22+
fpn_num_filters=config.fpn_num_filters,
23+
conv_bn_act_pattern=config.conv_bn_act_pattern,
24+
act_type=config.act_type,
25+
pooling_type=config.pooling_type,
26+
separable_conv=config.separable_conv,
27+
use_tpu=config.use_tpu
28+
)
29+
30+
layer_config = keras_bifpn.get_config()
31+
new_layer = efficientdet_arch_keras.BiFPNLayer(**layer_config)
32+
self.assertDictEqual(new_layer.get_config(), layer_config)
33+
34+
class KerasTest(tf.test.TestCase):
35+
def test_resample_feature_map(self):
36+
feat = tf.random.uniform([1, 16, 16, 320])
37+
for apply_fn in [True, False]:
38+
for is_training in [True, False]:
39+
for use_tpu in [True, False]:
40+
with self.subTest(apply_fn=apply_fn,
41+
is_training=is_training,
42+
use_tpu=use_tpu):
43+
tf.random.set_random_seed(111111)
44+
expect_result = efficientdet_arch.resample_feature_map(
45+
feat,
46+
name='resample_p0',
47+
target_height=8,
48+
target_width=8,
49+
target_num_channels=64,
50+
apply_bn=apply_fn,
51+
is_training=is_training,
52+
use_tpu=use_tpu)
53+
tf.random.set_random_seed(111111)
54+
actual_result = efficientdet_arch_keras.ResampleFeatureMap(
55+
name='resample_p0',
56+
target_height=8,
57+
target_width=8,
58+
target_num_channels=64,
59+
apply_bn=apply_fn,
60+
is_training=is_training,
61+
use_tpu=use_tpu)(feat)
62+
self.assertAllCloseAccordingToType(expect_result, actual_result)
2963

30-
layer_config = keras_bifpn.get_config()
31-
new_layer = arch_keras.BiFPNLayer(**layer_config)
32-
self.assertDictEqual(new_layer.get_config(), layer_config)
64+
@deprecated_graph_mode_only
65+
def test_name(self):
66+
feat = tf.random.uniform([1, 16, 16, 320])
67+
actual_result = efficientdet_arch_keras.ResampleFeatureMap(
68+
name='p0',
69+
target_height=8,
70+
target_width=8,
71+
target_num_channels=64)(feat)
72+
self.assertEqual("resample_p0/max_pooling2d/MaxPool:0", actual_result.name)
3373

3474

3575
if __name__ == '__main__':
36-
tf.test.main()
76+
tf.test.main()

0 commit comments

Comments
 (0)