Skip to content

Commit 9dd547a

Browse files
authored
[Port] OD layers to Keras 3 (#2295)
* chore: porting roi align to keras 3 * chore: fixing the scope, using ones in place of constant * chore: porting roi generation to keras 3 with test note: the nms bit reproduces -1 instead of 0 * chore: port roi pooling * chore: fix pool and port sampler * chore: port label encoder * chore: swap get_shape with ops.shape * lint error * chore: porting sampling to keras 3 * lint fix * chore: using random from backend * chore: disabling flaky test * chore: disable roi sampler test * chore: ignore lint * chore: skipping test the right way * chore: using ops shape * chore: tests pass for all backends removed vectorized map as it was not working for jax and torch used ops convert_to_numpy in tests to make np operations work on torch tensor * chore: explicit type cast to int32
1 parent 9207602 commit 9dd547a

File tree

11 files changed

+584
-583
lines changed

11 files changed

+584
-583
lines changed

keras_cv/layers/object_detection/roi_align.py

Lines changed: 217 additions & 215 deletions
Large diffs are not rendered by default.

keras_cv/layers/object_detection/roi_generator.py

Lines changed: 31 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Mapping
1615
from typing import Optional
17-
from typing import Tuple
18-
from typing import Union
1916

20-
import tensorflow as tf
21-
from tensorflow import keras
22-
23-
from keras_cv import bounding_box
2417
from keras_cv.api_export import keras_cv_export
25-
from keras_cv.backend import assert_tf_keras
18+
from keras_cv.backend import keras
19+
from keras_cv.backend import ops
20+
from keras_cv.layers import NonMaxSuppression
2621

2722

2823
@keras_cv_export("keras_cv.layers.ROIGenerator")
@@ -97,7 +92,6 @@ def __init__(
9792
post_nms_topk_test: int = 1000,
9893
**kwargs,
9994
):
100-
assert_tf_keras("keras_cv.layers.ROIGenerator")
10195
super().__init__(**kwargs)
10296
self.bounding_box_format = bounding_box_format
10397
self.pre_nms_topk_train = pre_nms_topk_train
@@ -112,10 +106,10 @@ def __init__(
112106

113107
def call(
114108
self,
115-
multi_level_boxes: Union[tf.Tensor, Mapping[int, tf.Tensor]],
116-
multi_level_scores: Union[tf.Tensor, Mapping[int, tf.Tensor]],
109+
multi_level_boxes,
110+
multi_level_scores,
117111
training: Optional[bool] = None,
118-
) -> Tuple[tf.Tensor, tf.Tensor]:
112+
):
119113
"""
120114
Args:
121115
multi_level_boxes: float Tensor. A dictionary or single Tensor of
@@ -131,7 +125,6 @@ def call(
131125
rois: float Tensor of [batch_size, post_nms_topk, 4]
132126
roi_scores: float Tensor of [batch_size, post_nms_topk]
133127
"""
134-
135128
if training:
136129
pre_nms_topk = self.pre_nms_topk_train
137130
post_nms_topk = self.post_nms_topk_train
@@ -144,53 +137,35 @@ def call(
144137
nms_iou_threshold = self.nms_iou_threshold_test
145138

146139
def per_level_gen(boxes, scores):
147-
scores_shape = scores.get_shape().as_list()
148-
# scores can also be [batch_size, num_boxes, 1]
140+
boxes = ops.convert_to_tensor(boxes, dtype="float32")
141+
scores = ops.convert_to_tensor(scores, dtype="float32")
142+
scores_shape = ops.shape(scores)
143+
# Check if scores is a 3-dimensional tensor
144+
# ([batch_size, num_boxes, 1])
145+
# If so, remove the last dimension to make it 2D
149146
if len(scores_shape) == 3:
150-
scores = tf.squeeze(scores, axis=-1)
151-
_, num_boxes = scores.get_shape().as_list()
147+
scores = ops.squeeze(scores, axis=-1)
148+
_, num_boxes = scores_shape
152149
level_pre_nms_topk = min(num_boxes, pre_nms_topk)
153150
level_post_nms_topk = min(num_boxes, post_nms_topk)
154-
scores, sorted_indices = tf.nn.top_k(
151+
scores, sorted_indices = ops.top_k(
155152
scores, k=level_pre_nms_topk, sorted=True
156153
)
157-
boxes = tf.gather(boxes, sorted_indices, batch_dims=1)
158-
# convert from input format to yxyx for the TF NMS operation
159-
boxes = bounding_box.convert_format(
160-
boxes,
161-
source=self.bounding_box_format,
162-
target="yxyx",
154+
boxes = ops.take_along_axis(
155+
boxes, sorted_indices[..., None], axis=1
163156
)
164157
# TODO(tanzhenyu): consider supporting soft / batched nms for accl
165-
selected_indices, num_valid = tf.image.non_max_suppression_padded(
166-
boxes,
167-
scores,
168-
max_output_size=level_post_nms_topk,
158+
boxes = NonMaxSuppression(
159+
bounding_box_format=self.bounding_box_format,
160+
from_logits=False,
169161
iou_threshold=nms_iou_threshold,
170-
score_threshold=nms_score_threshold,
171-
pad_to_max_output_size=True,
172-
sorted_input=True,
173-
canonicalized_coordinates=True,
174-
)
175-
# convert back to input format
176-
boxes = bounding_box.convert_format(
177-
boxes,
178-
source="yxyx",
179-
target=self.bounding_box_format,
180-
)
181-
level_rois = tf.gather(boxes, selected_indices, batch_dims=1)
182-
level_roi_scores = tf.gather(scores, selected_indices, batch_dims=1)
183-
level_rois = level_rois * tf.cast(
184-
tf.reshape(tf.range(level_post_nms_topk), [1, -1, 1])
185-
< tf.reshape(num_valid, [-1, 1, 1]),
186-
level_rois.dtype,
187-
)
188-
level_roi_scores = level_roi_scores * tf.cast(
189-
tf.reshape(tf.range(level_post_nms_topk), [1, -1])
190-
< tf.reshape(num_valid, [-1, 1]),
191-
level_roi_scores.dtype,
162+
confidence_threshold=nms_score_threshold,
163+
max_detections=level_post_nms_topk,
164+
)(
165+
box_prediction=boxes,
166+
class_prediction=scores[..., None],
192167
)
193-
return level_rois, level_roi_scores
168+
return boxes["boxes"], boxes["confidence"]
194169

195170
if not isinstance(multi_level_boxes, dict):
196171
return per_level_gen(multi_level_boxes, multi_level_scores)
@@ -204,14 +179,14 @@ def per_level_gen(boxes, scores):
204179
rois.append(level_rois)
205180
roi_scores.append(level_roi_scores)
206181

207-
rois = tf.concat(rois, axis=1)
208-
roi_scores = tf.concat(roi_scores, axis=1)
209-
_, num_valid_rois = roi_scores.get_shape().as_list()
182+
rois = ops.concatenate(rois, axis=1)
183+
roi_scores = ops.concatenate(roi_scores, axis=1)
184+
_, num_valid_rois = ops.shape(roi_scores)
210185
overall_top_k = min(num_valid_rois, post_nms_topk)
211-
roi_scores, sorted_indices = tf.nn.top_k(
186+
roi_scores, sorted_indices = ops.top_k(
212187
roi_scores, k=overall_top_k, sorted=True
213188
)
214-
rois = tf.gather(rois, sorted_indices, batch_dims=1)
189+
rois = ops.take_along_axis(rois, sorted_indices[..., None], axis=1)
215190

216191
return rois, roi_scores
217192

keras_cv/layers/object_detection/roi_generator_test.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
1516
import pytest
16-
import tensorflow as tf
1717

1818
from keras_cv.layers.object_detection.roi_generator import ROIGenerator
1919
from keras_cv.tests.test_case import TestCase
@@ -23,7 +23,7 @@
2323
class ROIGeneratorTest(TestCase):
2424
def test_single_tensor(self):
2525
roi_generator = ROIGenerator("xyxy", nms_iou_threshold_train=0.96)
26-
rpn_boxes = tf.constant(
26+
rpn_boxes = np.array(
2727
[
2828
[
2929
[0, 0, 10, 10],
@@ -33,26 +33,33 @@ def test_single_tensor(self):
3333
],
3434
]
3535
)
36-
expected_rois = tf.gather(rpn_boxes, [[1, 3, 2]], batch_dims=1)
37-
expected_rois = tf.concat([expected_rois, tf.zeros([1, 1, 4])], axis=1)
38-
rpn_scores = tf.constant(
36+
indices = [1, 3, 2]
37+
expected_rois = np.take(rpn_boxes, indices, axis=1)
38+
expected_rois = np.concatenate(
39+
[expected_rois, -np.ones([1, 1, 4])], axis=1
40+
)
41+
rpn_scores = np.array(
3942
[
4043
[0.6, 0.9, 0.2, 0.3],
4144
]
4245
)
4346
# selecting the 1st, then 3rd, then 2nd as they don't overlap
4447
# 0th box overlaps with 1st box
45-
expected_roi_scores = tf.gather(rpn_scores, [[1, 3, 2]], batch_dims=1)
46-
expected_roi_scores = tf.concat(
47-
[expected_roi_scores, tf.zeros([1, 1])], axis=1
48+
expected_roi_scores = np.take(rpn_scores, indices, axis=1)
49+
expected_roi_scores = np.concatenate(
50+
[expected_roi_scores, -np.ones([1, 1])], axis=1
51+
)
52+
rois, roi_scores = roi_generator(
53+
multi_level_boxes=rpn_boxes,
54+
multi_level_scores=rpn_scores,
55+
training=True,
4856
)
49-
rois, roi_scores = roi_generator(rpn_boxes, rpn_scores, training=True)
5057
self.assertAllClose(expected_rois, rois)
5158
self.assertAllClose(expected_roi_scores, roi_scores)
5259

5360
def test_single_level_single_batch_roi_ignore_box(self):
5461
roi_generator = ROIGenerator("xyxy", nms_iou_threshold_train=0.96)
55-
rpn_boxes = tf.constant(
62+
rpn_boxes = np.array(
5663
[
5764
[
5865
[0, 0, 10, 10],
@@ -62,19 +69,22 @@ def test_single_level_single_batch_roi_ignore_box(self):
6269
],
6370
]
6471
)
65-
expected_rois = tf.gather(rpn_boxes, [[1, 3, 2]], batch_dims=1)
66-
expected_rois = tf.concat([expected_rois, tf.zeros([1, 1, 4])], axis=1)
72+
indices = [1, 3, 2]
73+
expected_rois = np.take(rpn_boxes, indices, axis=1)
74+
expected_rois = np.concatenate(
75+
[expected_rois, -np.ones([1, 1, 4])], axis=1
76+
)
6777
rpn_boxes = {2: rpn_boxes}
68-
rpn_scores = tf.constant(
78+
rpn_scores = np.array(
6979
[
7080
[0.6, 0.9, 0.2, 0.3],
7181
]
7282
)
7383
# selecting the 1st, then 3rd, then 2nd as they don't overlap
7484
# 0th box overlaps with 1st box
75-
expected_roi_scores = tf.gather(rpn_scores, [[1, 3, 2]], batch_dims=1)
76-
expected_roi_scores = tf.concat(
77-
[expected_roi_scores, tf.zeros([1, 1])], axis=1
85+
expected_roi_scores = np.take(rpn_scores, indices, axis=1)
86+
expected_roi_scores = np.concatenate(
87+
[expected_roi_scores, -np.ones([1, 1])], axis=1
7888
)
7989
rpn_scores = {2: rpn_scores}
8090
rois, roi_scores = roi_generator(rpn_boxes, rpn_scores, training=True)
@@ -85,7 +95,7 @@ def test_single_level_single_batch_roi_all_box(self):
8595
# the IoU between the 1st and 2nd box is 0.9604, so set the threshold to
8696
# 0.97 such that NMS treats them as different ROIs
8797
roi_generator = ROIGenerator("xyxy", nms_iou_threshold_train=0.97)
88-
rpn_boxes = tf.constant(
98+
rpn_boxes = np.array(
8999
[
90100
[
91101
[0, 0, 10, 10],
@@ -95,25 +105,24 @@ def test_single_level_single_batch_roi_all_box(self):
95105
],
96106
]
97107
)
98-
expected_rois = tf.gather(rpn_boxes, [[1, 0, 3, 2]], batch_dims=1)
108+
indices = [1, 0, 3, 2]
109+
expected_rois = np.take(rpn_boxes, indices, axis=1)
99110
rpn_boxes = {2: rpn_boxes}
100-
rpn_scores = tf.constant(
111+
rpn_scores = np.array(
101112
[
102113
[0.6, 0.9, 0.2, 0.3],
103114
]
104115
)
105116
# selecting the 1st, then 0th, then 3rd, then 2nd as they don't overlap
106-
expected_roi_scores = tf.gather(
107-
rpn_scores, [[1, 0, 3, 2]], batch_dims=1
108-
)
117+
expected_roi_scores = np.take(rpn_scores, indices, axis=1)
109118
rpn_scores = {2: rpn_scores}
110119
rois, roi_scores = roi_generator(rpn_boxes, rpn_scores, training=True)
111120
self.assertAllClose(expected_rois, rois)
112121
self.assertAllClose(expected_roi_scores, roi_scores)
113122

114123
def test_single_level_propose_rois(self):
115124
roi_generator = ROIGenerator("xyxy")
116-
rpn_boxes = tf.constant(
125+
rpn_boxes = np.array(
117126
[
118127
[
119128
[0, 0, 10, 10],
@@ -129,21 +138,22 @@ def test_single_level_propose_rois(self):
129138
],
130139
]
131140
)
132-
expected_rois = tf.gather(
133-
rpn_boxes, [[1, 3, 2], [1, 3, 0]], batch_dims=1
141+
indices = np.array([[1, 3, 2], [1, 3, 0]])
142+
expected_rois = np.take_along_axis(
143+
rpn_boxes, indices[:, :, None], axis=1
144+
)
145+
expected_rois = np.concatenate(
146+
[expected_rois, -np.ones([2, 1, 4])], axis=1
134147
)
135-
expected_rois = tf.concat([expected_rois, tf.zeros([2, 1, 4])], axis=1)
136148
rpn_boxes = {2: rpn_boxes}
137-
rpn_scores = tf.constant([[0.6, 0.9, 0.2, 0.3], [0.1, 0.8, 0.3, 0.5]])
149+
rpn_scores = np.array([[0.6, 0.9, 0.2, 0.3], [0.1, 0.8, 0.3, 0.5]])
138150
# 1st batch -- selecting the 1st, then 3rd, then 2nd as they don't
139151
# overlap
140152
# 2nd batch -- selecting the 1st, then 3rd, then 0th as they don't
141153
# overlap
142-
expected_roi_scores = tf.gather(
143-
rpn_scores, [[1, 3, 2], [1, 3, 0]], batch_dims=1
144-
)
145-
expected_roi_scores = tf.concat(
146-
[expected_roi_scores, tf.zeros([2, 1])], axis=1
154+
expected_roi_scores = np.take_along_axis(rpn_scores, indices, axis=1)
155+
expected_roi_scores = np.concatenate(
156+
[expected_roi_scores, -np.ones([2, 1])], axis=1
147157
)
148158
rpn_scores = {2: rpn_scores}
149159
rois, roi_scores = roi_generator(rpn_boxes, rpn_scores, training=True)
@@ -152,7 +162,7 @@ def test_single_level_propose_rois(self):
152162

153163
def test_two_level_single_batch_propose_rois_ignore_box(self):
154164
roi_generator = ROIGenerator("xyxy")
155-
rpn_boxes = tf.constant(
165+
rpn_boxes = np.array(
156166
[
157167
[
158168
[0, 0, 10, 10],
@@ -168,7 +178,7 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
168178
],
169179
]
170180
)
171-
expected_rois = tf.constant(
181+
expected_rois = np.array(
172182
[
173183
[
174184
[0.1, 0.1, 9.9, 9.9],
@@ -177,13 +187,13 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
177187
[2, 2, 8, 8],
178188
[5, 5, 10, 10],
179189
[2, 2, 4, 4],
180-
[0, 0, 0, 0],
181-
[0, 0, 0, 0],
190+
[-1, -1, -1, -1],
191+
[-1, -1, -1, -1],
182192
]
183193
]
184194
)
185195
rpn_boxes = {2: rpn_boxes[0:1], 3: rpn_boxes[1:2]}
186-
rpn_scores = tf.constant([[0.6, 0.9, 0.2, 0.3], [0.1, 0.8, 0.3, 0.5]])
196+
rpn_scores = np.array([[0.6, 0.9, 0.2, 0.3], [0.1, 0.8, 0.3, 0.5]])
187197
# 1st batch -- selecting the 1st, then 3rd, then 2nd as they don't
188198
# overlap
189199
# 2nd batch -- selecting the 1st, then 3rd, then 0th as they don't
@@ -196,8 +206,8 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
196206
0.3,
197207
0.2,
198208
0.1,
199-
0.0,
200-
0.0,
209+
-1.0,
210+
-1.0,
201211
]
202212
]
203213
rpn_scores = {2: rpn_scores[0:1], 3: rpn_scores[1:2]}
@@ -207,7 +217,7 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
207217

208218
def test_two_level_single_batch_propose_rois_all_box(self):
209219
roi_generator = ROIGenerator("xyxy", nms_iou_threshold_train=0.99)
210-
rpn_boxes = tf.constant(
220+
rpn_boxes = np.array(
211221
[
212222
[
213223
[0, 0, 10, 10],
@@ -223,7 +233,7 @@ def test_two_level_single_batch_propose_rois_all_box(self):
223233
],
224234
]
225235
)
226-
expected_rois = tf.constant(
236+
expected_rois = np.array(
227237
[
228238
[
229239
[0.1, 0.1, 9.9, 9.9],
@@ -238,7 +248,7 @@ def test_two_level_single_batch_propose_rois_all_box(self):
238248
]
239249
)
240250
rpn_boxes = {2: rpn_boxes[0:1], 3: rpn_boxes[1:2]}
241-
rpn_scores = tf.constant([[0.6, 0.9, 0.2, 0.3], [0.1, 0.8, 0.3, 0.5]])
251+
rpn_scores = np.array([[0.6, 0.9, 0.2, 0.3], [0.1, 0.8, 0.3, 0.5]])
242252
# 1st batch -- selecting the 1st, then 0th, then 3rd, then 2nd as they
243253
# don't overlap
244254
# 2nd batch -- selecting the 1st, then 3rd, then 2nd, then 0th as they

0 commit comments

Comments
 (0)