1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import numpy as np
1516import pytest
16- import tensorflow as tf
1717
1818from keras_cv .layers .object_detection .roi_generator import ROIGenerator
1919from keras_cv .tests .test_case import TestCase
2323class ROIGeneratorTest (TestCase ):
2424 def test_single_tensor (self ):
2525 roi_generator = ROIGenerator ("xyxy" , nms_iou_threshold_train = 0.96 )
26- rpn_boxes = tf . constant (
26+ rpn_boxes = np . array (
2727 [
2828 [
2929 [0 , 0 , 10 , 10 ],
@@ -33,26 +33,33 @@ def test_single_tensor(self):
3333 ],
3434 ]
3535 )
36- expected_rois = tf .gather (rpn_boxes , [[1 , 3 , 2 ]], batch_dims = 1 )
37- expected_rois = tf .concat ([expected_rois , tf .zeros ([1 , 1 , 4 ])], axis = 1 )
38- rpn_scores = tf .constant (
36+ indices = [1 , 3 , 2 ]
37+ expected_rois = np .take (rpn_boxes , indices , axis = 1 )
38+ expected_rois = np .concatenate (
39+ [expected_rois , - np .ones ([1 , 1 , 4 ])], axis = 1
40+ )
41+ rpn_scores = np .array (
3942 [
4043 [0.6 , 0.9 , 0.2 , 0.3 ],
4144 ]
4245 )
4346 # selecting the 1st, then 3rd, then 2nd as they don't overlap
4447 # 0th box overlaps with 1st box
45- expected_roi_scores = tf .gather (rpn_scores , [[1 , 3 , 2 ]], batch_dims = 1 )
46- expected_roi_scores = tf .concat (
47- [expected_roi_scores , tf .zeros ([1 , 1 ])], axis = 1
48+ expected_roi_scores = np .take (rpn_scores , indices , axis = 1 )
49+ expected_roi_scores = np .concatenate (
50+ [expected_roi_scores , - np .ones ([1 , 1 ])], axis = 1
51+ )
52+ rois , roi_scores = roi_generator (
53+ multi_level_boxes = rpn_boxes ,
54+ multi_level_scores = rpn_scores ,
55+ training = True ,
4856 )
49- rois , roi_scores = roi_generator (rpn_boxes , rpn_scores , training = True )
5057 self .assertAllClose (expected_rois , rois )
5158 self .assertAllClose (expected_roi_scores , roi_scores )
5259
5360 def test_single_level_single_batch_roi_ignore_box (self ):
5461 roi_generator = ROIGenerator ("xyxy" , nms_iou_threshold_train = 0.96 )
55- rpn_boxes = tf . constant (
62+ rpn_boxes = np . array (
5663 [
5764 [
5865 [0 , 0 , 10 , 10 ],
@@ -62,19 +69,22 @@ def test_single_level_single_batch_roi_ignore_box(self):
6269 ],
6370 ]
6471 )
65- expected_rois = tf .gather (rpn_boxes , [[1 , 3 , 2 ]], batch_dims = 1 )
66- expected_rois = tf .concat ([expected_rois , tf .zeros ([1 , 1 , 4 ])], axis = 1 )
72+ indices = [1 , 3 , 2 ]
73+ expected_rois = np .take (rpn_boxes , indices , axis = 1 )
74+ expected_rois = np .concatenate (
75+ [expected_rois , - np .ones ([1 , 1 , 4 ])], axis = 1
76+ )
6777 rpn_boxes = {2 : rpn_boxes }
68- rpn_scores = tf . constant (
78+ rpn_scores = np . array (
6979 [
7080 [0.6 , 0.9 , 0.2 , 0.3 ],
7181 ]
7282 )
7383 # selecting the 1st, then 3rd, then 2nd as they don't overlap
7484 # 0th box overlaps with 1st box
75- expected_roi_scores = tf . gather (rpn_scores , [[ 1 , 3 , 2 ]], batch_dims = 1 )
76- expected_roi_scores = tf . concat (
77- [expected_roi_scores , tf . zeros ([1 , 1 ])], axis = 1
85+ expected_roi_scores = np . take (rpn_scores , indices , axis = 1 )
86+ expected_roi_scores = np . concatenate (
87+ [expected_roi_scores , - np . ones ([1 , 1 ])], axis = 1
7888 )
7989 rpn_scores = {2 : rpn_scores }
8090 rois , roi_scores = roi_generator (rpn_boxes , rpn_scores , training = True )
@@ -85,7 +95,7 @@ def test_single_level_single_batch_roi_all_box(self):
8595 # for iou between 1st and 2nd box is 0.9604, so setting to 0.97 to
8696 # such that NMS would treat them as different ROIs
8797 roi_generator = ROIGenerator ("xyxy" , nms_iou_threshold_train = 0.97 )
88- rpn_boxes = tf . constant (
98+ rpn_boxes = np . array (
8999 [
90100 [
91101 [0 , 0 , 10 , 10 ],
@@ -95,25 +105,24 @@ def test_single_level_single_batch_roi_all_box(self):
95105 ],
96106 ]
97107 )
98- expected_rois = tf .gather (rpn_boxes , [[1 , 0 , 3 , 2 ]], batch_dims = 1 )
108+ indices = [1 , 0 , 3 , 2 ]
109+ expected_rois = np .take (rpn_boxes , indices , axis = 1 )
99110 rpn_boxes = {2 : rpn_boxes }
100- rpn_scores = tf . constant (
111+ rpn_scores = np . array (
101112 [
102113 [0.6 , 0.9 , 0.2 , 0.3 ],
103114 ]
104115 )
105116 # selecting the 1st, then 0th, then 3rd, then 2nd as they don't overlap
106- expected_roi_scores = tf .gather (
107- rpn_scores , [[1 , 0 , 3 , 2 ]], batch_dims = 1
108- )
117+ expected_roi_scores = np .take (rpn_scores , indices , axis = 1 )
109118 rpn_scores = {2 : rpn_scores }
110119 rois , roi_scores = roi_generator (rpn_boxes , rpn_scores , training = True )
111120 self .assertAllClose (expected_rois , rois )
112121 self .assertAllClose (expected_roi_scores , roi_scores )
113122
114123 def test_single_level_propose_rois (self ):
115124 roi_generator = ROIGenerator ("xyxy" )
116- rpn_boxes = tf . constant (
125+ rpn_boxes = np . array (
117126 [
118127 [
119128 [0 , 0 , 10 , 10 ],
@@ -129,21 +138,22 @@ def test_single_level_propose_rois(self):
129138 ],
130139 ]
131140 )
132- expected_rois = tf .gather (
133- rpn_boxes , [[1 , 3 , 2 ], [1 , 3 , 0 ]], batch_dims = 1
141+ indices = np .array ([[1 , 3 , 2 ], [1 , 3 , 0 ]])
142+ expected_rois = np .take_along_axis (
143+ rpn_boxes , indices [:, :, None ], axis = 1
144+ )
145+ expected_rois = np .concatenate (
146+ [expected_rois , - np .ones ([2 , 1 , 4 ])], axis = 1
134147 )
135- expected_rois = tf .concat ([expected_rois , tf .zeros ([2 , 1 , 4 ])], axis = 1 )
136148 rpn_boxes = {2 : rpn_boxes }
137- rpn_scores = tf . constant ([[0.6 , 0.9 , 0.2 , 0.3 ], [0.1 , 0.8 , 0.3 , 0.5 ]])
149+ rpn_scores = np . array ([[0.6 , 0.9 , 0.2 , 0.3 ], [0.1 , 0.8 , 0.3 , 0.5 ]])
138150 # 1st batch -- selecting the 1st, then 3rd, then 2nd as they don't
139151 # overlap
140152 # 2nd batch -- selecting the 1st, then 3rd, then 0th as they don't
141153 # overlap
142- expected_roi_scores = tf .gather (
143- rpn_scores , [[1 , 3 , 2 ], [1 , 3 , 0 ]], batch_dims = 1
144- )
145- expected_roi_scores = tf .concat (
146- [expected_roi_scores , tf .zeros ([2 , 1 ])], axis = 1
154+ expected_roi_scores = np .take_along_axis (rpn_scores , indices , axis = 1 )
155+ expected_roi_scores = np .concatenate (
156+ [expected_roi_scores , - np .ones ([2 , 1 ])], axis = 1
147157 )
148158 rpn_scores = {2 : rpn_scores }
149159 rois , roi_scores = roi_generator (rpn_boxes , rpn_scores , training = True )
@@ -152,7 +162,7 @@ def test_single_level_propose_rois(self):
152162
153163 def test_two_level_single_batch_propose_rois_ignore_box (self ):
154164 roi_generator = ROIGenerator ("xyxy" )
155- rpn_boxes = tf . constant (
165+ rpn_boxes = np . array (
156166 [
157167 [
158168 [0 , 0 , 10 , 10 ],
@@ -168,7 +178,7 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
168178 ],
169179 ]
170180 )
171- expected_rois = tf . constant (
181+ expected_rois = np . array (
172182 [
173183 [
174184 [0.1 , 0.1 , 9.9 , 9.9 ],
@@ -177,13 +187,13 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
177187 [2 , 2 , 8 , 8 ],
178188 [5 , 5 , 10 , 10 ],
179189 [2 , 2 , 4 , 4 ],
180- [0 , 0 , 0 , 0 ],
181- [0 , 0 , 0 , 0 ],
190+ [- 1 , - 1 , - 1 , - 1 ],
191+ [- 1 , - 1 , - 1 , - 1 ],
182192 ]
183193 ]
184194 )
185195 rpn_boxes = {2 : rpn_boxes [0 :1 ], 3 : rpn_boxes [1 :2 ]}
186- rpn_scores = tf . constant ([[0.6 , 0.9 , 0.2 , 0.3 ], [0.1 , 0.8 , 0.3 , 0.5 ]])
196+ rpn_scores = np . array ([[0.6 , 0.9 , 0.2 , 0.3 ], [0.1 , 0.8 , 0.3 , 0.5 ]])
187197 # 1st batch -- selecting the 1st, then 3rd, then 2nd as they don't
188198 # overlap
189199 # 2nd batch -- selecting the 1st, then 3rd, then 0th as they don't
@@ -196,8 +206,8 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
196206 0.3 ,
197207 0.2 ,
198208 0.1 ,
199- 0 .0 ,
200- 0 .0 ,
209+ - 1 .0 ,
210+ - 1 .0 ,
201211 ]
202212 ]
203213 rpn_scores = {2 : rpn_scores [0 :1 ], 3 : rpn_scores [1 :2 ]}
@@ -207,7 +217,7 @@ def test_two_level_single_batch_propose_rois_ignore_box(self):
207217
208218 def test_two_level_single_batch_propose_rois_all_box (self ):
209219 roi_generator = ROIGenerator ("xyxy" , nms_iou_threshold_train = 0.99 )
210- rpn_boxes = tf . constant (
220+ rpn_boxes = np . array (
211221 [
212222 [
213223 [0 , 0 , 10 , 10 ],
@@ -223,7 +233,7 @@ def test_two_level_single_batch_propose_rois_all_box(self):
223233 ],
224234 ]
225235 )
226- expected_rois = tf . constant (
236+ expected_rois = np . array (
227237 [
228238 [
229239 [0.1 , 0.1 , 9.9 , 9.9 ],
@@ -238,7 +248,7 @@ def test_two_level_single_batch_propose_rois_all_box(self):
238248 ]
239249 )
240250 rpn_boxes = {2 : rpn_boxes [0 :1 ], 3 : rpn_boxes [1 :2 ]}
241- rpn_scores = tf . constant ([[0.6 , 0.9 , 0.2 , 0.3 ], [0.1 , 0.8 , 0.3 , 0.5 ]])
251+ rpn_scores = np . array ([[0.6 , 0.9 , 0.2 , 0.3 ], [0.1 , 0.8 , 0.3 , 0.5 ]])
242252 # 1st batch -- selecting the 1st, then 0th, then 3rd, then 2nd as they
243253 # don't overlap
244254 # 2nd batch -- selecting the 1st, then 3rd, then 2nd, then 0th as they
0 commit comments