Skip to content

Commit c7b2011

Browse files
authored
Update random_sampler.py (#191)
Temporary fix for torch.randperm, from: open-mmlab#5014
1 parent a90acdd commit c7b2011

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

mmdet/core/bbox/samplers/random_sampler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ def random_choice(self, gallery, num):
5151
else:
5252
device = 'cpu'
5353
gallery = torch.tensor(gallery, dtype=torch.long, device=device)
54-
perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
54+
# This is a temporary fix. We can revert the following code
55+
# when PyTorch fixes the abnormal return of torch.randperm.
56+
# See: https://github.com/open-mmlab/mmdetection/pull/5014
57+
perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
5558
rand_inds = gallery[perm]
5659
if not is_tensor:
5760
rand_inds = rand_inds.cpu().numpy()

0 commit comments

Comments
 (0)