Skip to content

Commit ea3a8f8

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
fix BitMasks indexing
Summary: fix #3591 Differential Revision: D31771140 fbshipit-source-id: cefca0b49e9b061fe45be6bc6d34ed8e1853c7fa
1 parent 8a1589e commit ea3a8f8

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

detectron2/structures/masks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
130130
subject to Pytorch's indexing semantics.
131131
"""
132132
if isinstance(item, int):
133-
return BitMasks(self.tensor[item].view(1, -1))
133+
return BitMasks(self.tensor[item].unsqueeze(0))
134134
m = self.tensor[item]
135135
assert m.dim() == 3, "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
136136
item, m.shape

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ cloudpickle
1515
Pillow
1616
future
1717
git+git://github.com/facebookresearch/fvcore.git
18-
https://download.pytorch.org/whl/cpu/torch-1.8.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
18+
https://download.pytorch.org/whl/cpu/torch-1.8.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
1919
https://download.pytorch.org/whl/cpu/torchvision-0.9.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
2020
omegaconf>=2.1.0.dev24
2121
hydra-core>=1.1.0.dev5

tests/structures/test_masks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ def test_from_empty_polygons(self):
4242
masks = BitMasks.from_polygon_masks([], 100, 100)
4343
self.assertEqual(masks.tensor.shape, (0, 100, 100))
4444

45+
def test_getitem(self):
46+
masks = BitMasks(torch.ones(3, 10, 10))
47+
self.assertEqual(masks[1].tensor.shape, (1, 10, 10))
48+
self.assertEqual(masks[1:3].tensor.shape, (2, 10, 10))
49+
self.assertEqual(masks[torch.tensor([True, False, False])].tensor.shape, (1, 10, 10))
50+
4551

4652
if __name__ == "__main__":
4753
unittest.main()

0 commit comments

Comments
 (0)