|
15 | 15 | MapDataset, |
16 | 16 | ToIterableDataset, |
17 | 17 | build_batch_data_loader, |
| 18 | + build_detection_test_loader, |
18 | 19 | build_detection_train_loader, |
19 | 20 | ) |
20 | 21 | from detectron2.data.samplers import InferenceSampler, TrainingSampler |
@@ -82,25 +83,46 @@ def _get_kwargs(self): |
82 | 83 | kwargs = {k: instantiate(v) for k, v in cfg.items()} |
83 | 84 | return kwargs |
84 | 85 |
|
85 | | - def test_build_dataloader(self): |
| 86 | + def test_build_dataloader_train(self): |
86 | 87 | kwargs = self._get_kwargs() |
87 | 88 | dl = build_detection_train_loader(**kwargs) |
88 | 89 | next(iter(dl)) |
89 | 90 |
|
90 | | - def test_build_iterable_dataloader(self): |
| 91 | + def test_build_iterable_dataloader_train(self): |
91 | 92 | kwargs = self._get_kwargs() |
92 | 93 | ds = DatasetFromList(kwargs.pop("dataset")) |
93 | 94 | ds = ToIterableDataset(ds, TrainingSampler(len(ds))) |
94 | 95 | dl = build_detection_train_loader(dataset=ds, **kwargs) |
95 | 96 | next(iter(dl)) |
96 | 97 |
|
97 | | - def test_build_dataloader_inference(self): |
| 98 | + def _check_is_range(self, data_loader, N): |
| 99 | + # check that data_loader produces range(N) |
| 100 | + data = list(iter(data_loader)) |
| 101 | + data = [x for batch in data for x in batch] # flatten the batches |
| 102 | + self.assertEqual(len(data), N) |
| 103 | + self.assertEqual(set(data), set(range(N))) |
| 104 | + |
| 105 | + def test_build_batch_dataloader_inference(self): |
| 106 | + # Test that build_batch_data_loader can be used for inference |
98 | 107 | N = 96 |
99 | 108 | ds = DatasetFromList(list(range(N))) |
100 | 109 | sampler = InferenceSampler(len(ds)) |
101 | 110 | dl = build_batch_data_loader(ds, sampler, 8, num_workers=3) |
| 111 | + self._check_is_range(dl, N) |
102 | 112 |
|
103 | | - data = list(iter(dl)) |
104 | | - data = [x for batch in data for x in batch] # flatten the batches |
105 | | - self.assertEqual(len(data), N) |
106 | | - self.assertEqual(set(data), set(range(N))) |
| 113 | + def test_build_dataloader_inference(self): |
| 114 | + N = 50 |
| 115 | + ds = DatasetFromList(list(range(N))) |
| 116 | + sampler = InferenceSampler(len(ds)) |
| 117 | + dl = build_detection_test_loader( |
| 118 | + dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3 |
| 119 | + ) |
| 120 | + self._check_is_range(dl, N) |
| 121 | + |
| 122 | + def test_build_iterable_dataloader_inference(self): |
| 123 | + # Test that build_detection_test_loader supports iterable dataset |
| 124 | + N = 50 |
| 125 | + ds = DatasetFromList(list(range(N))) |
| 126 | + ds = ToIterableDataset(ds, InferenceSampler(len(ds))) |
| 127 | + dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3) |
| 128 | + self._check_is_range(dl, N) |
0 commit comments