Skip to content

Commit 9676061

Browse files
authored
Add anchor_generator, box_matcher and non_max_supression (#1849)
* Add anchor_generator, box_matcher and non_max_supression * nit * nit
1 parent 22ce1d5 commit 9676061

File tree

7 files changed

+1336
-0
lines changed

7 files changed

+1336
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasHub Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright 2024 The KerasHub Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
17+
import keras
18+
from keras import ops
19+
20+
from keras_hub.src.bounding_box.converters import convert_format
21+
22+
23+
class AnchorGenerator(keras.layers.Layer):
    """Generates anchor boxes for object detection tasks.

    This layer creates a set of anchor boxes (also known as default boxes or
    priors) for use in object detection models, particularly those utilizing
    Feature Pyramid Networks (FPN). It generates anchors across multiple
    pyramid levels, with various scales and aspect ratios.

    Feature Pyramid Levels:
    - Levels typically range from 2 to 6 (P2 to P7), corresponding to
      different resolutions of the input image.
    - Each level l has a stride of 2^l pixels relative to the input image.
    - Lower levels (e.g., P2) have higher resolution and are used for
      detecting smaller objects.
    - Higher levels (e.g., P7) have lower resolution and are used
      for larger objects.

    Args:
        bounding_box_format (str): The format of the bounding boxes
            to be generated. Expected to be a string like 'xyxy', 'xywh', etc.
        min_level (int): Minimum level of the output feature pyramid.
        max_level (int): Maximum level of the output feature pyramid.
        num_scales (int): Number of intermediate scales added on each level.
            For example, num_scales=2 adds one additional intermediate anchor
            scale [2^0, 2^0.5] on each level.
        aspect_ratios (list of float): Aspect ratios of anchors added on
            each level. Each number indicates the ratio of width to height.
        anchor_size (float): Scale of size of the base anchor relative to the
            feature stride 2^level.

    Call arguments:
        images (Optional[Tensor]): An image tensor with shape `[B, H, W, C]`
            or `[H, W, C]`. If provided, its shape will be used to determine
            anchor sizes.

    Returns:
        Dict: A dictionary mapping feature levels
        (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a
        tensor of shape
        `(H/stride * W/stride * num_anchors_per_location, 4)`,
        where H and W are the height and width of the image,
        stride is 2^level, and num_anchors_per_location is
        `num_scales * len(aspect_ratios)`.

    Example:
    ```python
    anchor_generator = AnchorGenerator(
        bounding_box_format='xyxy',
        min_level=3,
        max_level=7,
        num_scales=3,
        aspect_ratios=[0.5, 1.0, 2.0],
        anchor_size=4.0,
    )
    anchors = anchor_generator(images=keras.ops.ones(shape=(2, 640, 480, 3)))
    ```
    """

    def __init__(
        self,
        bounding_box_format,
        min_level,
        max_level,
        num_scales,
        aspect_ratios,
        anchor_size,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bounding_box_format = bounding_box_format
        self.min_level = min_level
        self.max_level = max_level
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios
        self.anchor_size = anchor_size
        # No weights to create; mark as built so Keras skips build().
        self.built = True

    def call(self, images):
        """Compute per-level anchor boxes from the image's spatial shape.

        Args:
            images: Tensor of shape `[B, H, W, C]` or `[H, W, C]`. Only the
                spatial dimensions are used.

        Returns:
            Dict mapping "P{level}" to a `(num_anchors, 4)` tensor of boxes
            in `self.bounding_box_format`.
        """
        images_shape = ops.shape(images)
        # Strip batch (if present) and channel dims to get (H, W).
        if len(images_shape) == 4:
            image_shape = images_shape[1:-1]
        else:
            image_shape = images_shape[:-1]

        image_shape = tuple(image_shape)

        multilevel_boxes = {}
        for level in range(self.min_level, self.max_level + 1):
            boxes_l = []
            # Feature map size for this level (ceil handles image sizes not
            # divisible by 2^level).
            feat_size_y = math.ceil(image_shape[0] / 2**level)
            feat_size_x = math.ceil(image_shape[1] / 2**level)

            # Effective stride (step size) for this level.
            stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
            stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")

            # Anchor center coordinates, starting at stride/2 so anchors are
            # centered on feature-map cells.
            cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
            cy = ops.arange(stride_y / 2, image_shape[0], stride_y)

            # Grid of anchor centers, shape (feat_size_y, feat_size_x).
            cx_grid, cy_grid = ops.meshgrid(cx, cy)

            for scale in range(self.num_scales):
                # Intermediate octave scale and base size depend only on
                # `scale`, so compute them once per scale (hoisted out of
                # the aspect-ratio loop).
                intermediate_scale = 2 ** (scale / self.num_scales)
                base_anchor_size = (
                    self.anchor_size * 2**level * intermediate_scale
                )
                for aspect_ratio in self.aspect_ratios:
                    # Adjust anchor dimensions based on aspect ratio
                    # (width/height), preserving the anchor's area.
                    aspect_x = aspect_ratio**0.5
                    aspect_y = aspect_ratio**-0.5
                    half_anchor_size_x = base_anchor_size * aspect_x / 2.0
                    half_anchor_size_y = base_anchor_size * aspect_y / 2.0

                    # Anchor boxes in (y1, x1, y2, x2) format.
                    boxes = ops.stack(
                        [
                            cy_grid - half_anchor_size_y,
                            cx_grid - half_anchor_size_x,
                            cy_grid + half_anchor_size_y,
                            cx_grid + half_anchor_size_x,
                        ],
                        axis=-1,
                    )
                    boxes_l.append(boxes)
            # Concat anchors on the same level to tensor shape HxWx(Ax4),
            # then flatten to (H*W*A, 4).
            boxes_l = ops.concatenate(boxes_l, axis=-1)
            boxes_l = ops.reshape(boxes_l, (-1, 4))
            # Convert to the user-defined bounding box format.
            multilevel_boxes[f"P{level}"] = convert_format(
                boxes_l,
                source="yxyx",
                target=self.bounding_box_format,
            )
        return multilevel_boxes

    def compute_output_shape(self, input_shape):
        """Return the (unknown-sized) output shape per pyramid level.

        NOTE(review): call() returns rank-2 `(num_anchors, 4)` tensors per
        level, while this declares `(None, None, 4)` — confirm intended.
        """
        multilevel_boxes_shape = {}
        for level in range(self.min_level, self.max_level + 1):
            multilevel_boxes_shape[f"P{level}"] = (None, None, 4)
        return multilevel_boxes_shape

    @property
    def anchors_per_location(self):
        """
        The `anchors_per_location` property returns the number of anchors
        generated per pixel location, which is equal to
        `num_scales * len(aspect_ratios)`.
        """
        return self.num_scales * len(self.aspect_ratios)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2024 The KerasHub Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import parameterized
16+
from keras import ops
17+
18+
from keras_hub.src.bounding_box.converters import convert_format
19+
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
20+
from keras_hub.src.tests.test_case import TestCase
21+
22+
23+
class AnchorGeneratorTest(TestCase):
    # Each parameterized case is a 7-tuple of constructor/call arguments
    # (bounding_box_format, min_level, max_level, num_scales, aspect_ratios,
    # anchor_size, image_shape) concatenated with a 1-tuple holding the
    # expected per-level boxes, expressed in "yxyx" and converted to the
    # case's format inside the test body.
    @parameterized.parameters(
        # Single scale anchor
        ("yxyx", 5, 5, 1, [1.0], 2.0, [64, 64])
        + (
            {
                "P5": [
                    [-16.0, -16.0, 48.0, 48.0],
                    [-16.0, 16.0, 48.0, 80.0],
                    [16.0, -16.0, 80.0, 48.0],
                    [16.0, 16.0, 80.0, 80.0],
                ]
            },
        ),
        # Multi scale anchor
        ("xywh", 5, 6, 1, [1.0], 2.0, [64, 64])
        + (
            {
                "P5": [
                    [-16.0, -16.0, 48.0, 48.0],
                    [-16.0, 16.0, 48.0, 80.0],
                    [16.0, -16.0, 80.0, 48.0],
                    [16.0, 16.0, 80.0, 80.0],
                ],
                "P6": [[-32, -32, 96, 96]],
            },
        ),
        # Multi aspect ratio anchor
        ("xyxy", 6, 6, 1, [1.0, 4.0, 0.25], 2.0, [64, 64])
        + (
            {
                "P6": [
                    [-32.0, -32.0, 96.0, 96.0],
                    [0.0, -96.0, 64.0, 160.0],
                    [-96.0, 0.0, 160.0, 64.0],
                ]
            },
        ),
        # Intermediate scales (num_scales=2 adds a 2^0.5-sized anchor)
        ("yxyx", 5, 5, 2, [1.0], 1.0, [32, 32])
        + (
            {
                "P5": [
                    [0.0, 0.0, 32.0, 32.0],
                    [
                        16 - 16 * 2**0.5,
                        16 - 16 * 2**0.5,
                        16 + 16 * 2**0.5,
                        16 + 16 * 2**0.5,
                    ],
                ]
            },
        ),
        # Non-square image
        ("xywh", 5, 5, 1, [1.0], 1.0, [64, 32])
        + ({"P5": [[0, 0, 32, 32], [32, 0, 64, 32]]},),
        # Image size indivisible by 2^level (stride becomes fractional)
        ("xyxy", 5, 5, 1, [1.0], 1.0, [40, 32])
        + ({"P5": [[-6, 0, 26, 32], [14, 0, 46, 32]]},),
    )
    def test_anchor_generator(
        self,
        bounding_box_format,
        min_level,
        max_level,
        num_scales,
        aspect_ratios,
        anchor_size,
        image_shape,
        expected_boxes,
    ):
        # Build the generator from the case parameters and run it on a
        # dummy image of the requested spatial size.
        anchor_generator = AnchorGenerator(
            bounding_box_format,
            min_level,
            max_level,
            num_scales,
            aspect_ratios,
            anchor_size,
        )
        images = ops.ones(shape=(1, image_shape[0], image_shape[1], 3))
        multilevel_boxes = anchor_generator(images=images)
        # Expected fixtures are written in "yxyx"; convert them to the
        # generator's output format before comparing.
        for key in expected_boxes:
            expected_boxes[key] = ops.convert_to_tensor(expected_boxes[key])
            expected_boxes[key] = convert_format(
                expected_boxes[key],
                source="yxyx",
                target=bounding_box_format,
            )
        self.assertAllClose(expected_boxes, multilevel_boxes)

0 commit comments

Comments
 (0)