@@ -137,35 +137,32 @@ def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs):
137137
138138 # TODO(ianstenbit): loss heatmap threshold should be configurable.
139139 box_regression_mask = (
140- ops .squeeze (
141- ops .take (
142- ops .reshape (heatmap , (ops .shape (heatmap )[0 ], - 1 )),
143- index [..., 0 ] * ops .shape (heatmap )[1 ] + index [..., 1 ],
144- axis = 1 ,
145- ),
146- axis = 0 ,
140+ ops .take_along_axis (
141+ ops .reshape (heatmap , (heatmap .shape [0 ], - 1 )),
142+ index [..., 0 ] * heatmap .shape [1 ] + index [..., 1 ],
143+ axis = 1 ,
147144 )
148145 > 0.95
149146 )
150147
151- box = ops .squeeze (
152- ops .take (
153- ops .reshape (box , (ops .shape (box )[0 ], - 1 , 7 )),
154- index [..., 0 ] * ops .shape (box )[1 ] + index [..., 1 ],
155- axis = 1 ,
148+ box = ops .take_along_axis (
149+ ops .reshape (box , (ops .shape (box )[0 ], - 1 , 7 )),
150+ ops .expand_dims (
151+ index [..., 0 ] * ops .shape (box )[1 ] + index [..., 1 ], axis = - 1
156152 ),
157- axis = 0 ,
153+ axis = 1 ,
158154 )
159- box_pred = ops .squeeze (
160- ops .take (
161- ops .reshape (
162- box_pred ,
163- (ops .shape (box_pred )[0 ], - 1 , ops .shape (box_pred )[- 1 ]),
164- ),
155+
156+ box_pred = ops .take_along_axis (
157+ ops .reshape (
158+ box_pred ,
159+ (ops .shape (box_pred )[0 ], - 1 , ops .shape (box_pred )[- 1 ]),
160+ ),
161+ ops .expand_dims (
165162 index [..., 0 ] * ops .shape (box_pred )[1 ] + index [..., 1 ],
166- axis = 1 ,
163+ axis = - 1 ,
167164 ),
168- axis = 0 ,
165+ axis = 1 ,
169166 )
170167
171168 box_center_mask = heatmap > 0.99
0 commit comments