@@ -43,7 +43,7 @@ def test_no_quantize(self):
4343 # | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
4444 # --------------------------------------------
4545 expected_feature_map = tf .reshape (
46- tf .constant ([27 , 31 , 59 , 63 ]), [1 , 2 , 2 , 1 ]
46+ tf .constant ([27 , 31 , 59 , 63 ]), [1 , 1 , 2 , 2 , 1 ]
4747 )
4848 self .assertAllClose (expected_feature_map , pooled_feature_map )
4949
@@ -69,7 +69,7 @@ def test_roi_quantize_y(self):
6969 # | 56, 57, 58(max) | 59, 60, 61, 62(max) | 63 (removed)
7070 # --------------------------------------------
7171 expected_feature_map = tf .reshape (
72- tf .constant ([26 , 30 , 58 , 62 ]), [1 , 2 , 2 , 1 ]
72+ tf .constant ([26 , 30 , 58 , 62 ]), [1 , 1 , 2 , 2 , 1 ]
7373 )
7474 self .assertAllClose (expected_feature_map , pooled_feature_map )
7575
@@ -94,7 +94,7 @@ def test_roi_quantize_x(self):
9494 # | 48, 49, 50, 51(max) | 52, 53, 54, 55(max) |
9595 # --------------------------------------------
9696 expected_feature_map = tf .reshape (
97- tf .constant ([19 , 23 , 51 , 55 ]), [1 , 2 , 2 , 1 ]
97+ tf .constant ([19 , 23 , 51 , 55 ]), [1 , 1 , 2 , 2 , 1 ]
9898 )
9999 self .assertAllClose (expected_feature_map , pooled_feature_map )
100100
@@ -121,7 +121,7 @@ def test_roi_quantize_h(self):
121121 # | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
122122 # --------------------------------------------
123123 expected_feature_map = tf .reshape (
124- tf .constant ([11 , 15 , 35 , 39 , 59 , 63 ]), [1 , 3 , 2 , 1 ]
124+ tf .constant ([11 , 15 , 35 , 39 , 59 , 63 ]), [1 , 1 , 3 , 2 , 1 ]
125125 )
126126 self .assertAllClose (expected_feature_map , pooled_feature_map )
127127
@@ -147,7 +147,7 @@ def test_roi_quantize_w(self):
147147 # | 56, 57(max) | 58, 59, 60(max) | 61, 62, 63(max) |
148148 # --------------------------------------------
149149 expected_feature_map = tf .reshape (
150- tf .constant ([25 , 28 , 31 , 57 , 60 , 63 ]), [1 , 2 , 3 , 1 ]
150+ tf .constant ([25 , 28 , 31 , 57 , 60 , 63 ]), [1 , 1 , 2 , 3 , 1 ]
151151 )
152152 self .assertAllClose (expected_feature_map , pooled_feature_map )
153153
@@ -168,7 +168,8 @@ def test_roi_feature_map_height_smaller_than_roi(self):
168168 # ------------------repeated----------------------
169169 # | 12, 13(max) | 14, 15(max) |
170170 expected_feature_map = tf .reshape (
171- tf .constant ([1 , 3 , 1 , 3 , 5 , 7 , 9 , 11 , 9 , 11 , 13 , 15 ]), [1 , 6 , 2 , 1 ]
171+ tf .constant ([1 , 3 , 1 , 3 , 5 , 7 , 9 , 11 , 9 , 11 , 13 , 15 ]),
172+ [1 , 1 , 6 , 2 , 1 ],
172173 )
173174 self .assertAllClose (expected_feature_map , pooled_feature_map )
174175
@@ -189,7 +190,7 @@ def test_roi_feature_map_width_smaller_than_roi(self):
189190 # --------------------------------------------
190191 expected_feature_map = tf .reshape (
191192 tf .constant ([4 , 4 , 5 , 6 , 6 , 7 , 12 , 12 , 13 , 14 , 14 , 15 ]),
192- [1 , 2 , 6 , 1 ],
193+ [1 , 1 , 2 , 6 , 1 ],
193194 )
194195 self .assertAllClose (expected_feature_map , pooled_feature_map )
195196
@@ -203,10 +204,43 @@ def test_roi_empty(self):
203204 rois = tf .reshape (tf .constant ([0.0 , 0.0 , 0.0 , 0.0 ]), [1 , 1 , 4 ])
204205 pooled_feature_map = roi_pooler (feature_map , rois )
205206 # all outputs should be top-left pixel
206- self .assertAllClose (tf .ones ([1 , 2 , 2 , 1 ]), pooled_feature_map )
207+ self .assertAllClose (tf .ones ([1 , 1 , 2 , 2 , 1 ]), pooled_feature_map )
207208
208209 def test_invalid_image_shape (self ):
209210 with self .assertRaisesRegex (ValueError , "dynamic shape" ):
210211 _ = ROIPooler (
211212 "rel_yxyx" , target_size = [2 , 2 ], image_shape = [None , 224 , 3 ]
212213 )
214+
215+ def test_multiple_rois (self ):
216+ feature_map = tf .expand_dims (
217+ tf .reshape (tf .range (0 , 64 ), [8 , 8 , 1 ]), axis = 0
218+ )
219+
220+ roi_pooler = ROIPooler (
221+ bounding_box_format = "yxyx" ,
222+ target_size = [2 , 2 ],
223+ image_shape = [224 , 224 , 3 ],
224+ )
225+ rois = tf .constant (
226+ [[[0.0 , 0.0 , 112.0 , 112.0 ], [0.0 , 112.0 , 224.0 , 224.0 ]]],
227+ )
228+
229+ pooled_feature_map = roi_pooler (feature_map , rois )
230+ # the maximum value would be at bottom-right at each block, roi sharded
231+ # into 2x2 blocks
232+ # | 0, 1, 2, 3 | 4, 5, 6, 7 |
233+ # | 8, 9, 10, 11 | 12, 13, 14, 15 |
234+ # | 16, 17, 18, 19 | 20, 21, 22, 23 |
235+ # | 24, 25, 26, 27(max) | 28, 29, 30, 31(max) |
236+ # --------------------------------------------
237+ # | 32, 33, 34, 35 | 36, 37, 38, 39 |
238+ # | 40, 41, 42, 43 | 44, 45, 46, 47 |
239+ # | 48, 49, 50, 51 | 52, 53, 54, 55 |
240+ # | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
241+ # --------------------------------------------
242+
243+ expected_feature_map = tf .reshape (
244+ tf .constant ([9 , 11 , 25 , 27 , 29 , 31 , 61 , 63 ]), [1 , 2 , 2 , 2 , 1 ]
245+ )
246+ self .assertAllClose (expected_feature_map , pooled_feature_map )
0 commit comments