Skip to content

Commit 757083d

Browse files
author
Beat Buesser
committed
Update channel transform
Signed-off-by: Beat Buesser <[email protected]>
1 parent 66ce420 commit 757083d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

art/estimators/object_detection/python_object_detector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,10 @@ def loss_gradient( # pylint: disable=W0613
308308

309309
if isinstance(x, np.ndarray):
310310
grads = np.stack(grad_list, axis=0)
311+
grads = np.transpose(grads, (0, 2, 3, 1))
311312
else:
312313
grads = torch.stack(grad_list, dim=0)
313-
grads = np.transpose(grads, (0, 2, 3, 1))
314+
grads = grads.premute(0, 2, 3, 1)
314315

315316
if self.clip_values is not None:
316317
grads = grads / self.clip_values[1]

0 commit comments

Comments
 (0)