
Commit a5f2845

ppwwyyxx authored and facebook-github-bot committed
support registering IterableDataset and build it from config.
Summary: Fix #4105
Pull Request resolved: #4180
Reviewed By: zhanghang1989
Differential Revision: D36045130
fbshipit-source-id: f2aa3bfdebf476737b6deec8eac93ef3043964b8
1 parent 0ad20f1 commit a5f2845
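
In practice, the change enables the workflow sketched below, which mirrors the unit test added in tests/data/test_dataset.py in this commit. The dataset name "my_stream" and the MyStream class are illustrative placeholders; get_cfg, DatasetCatalog.register, and build_detection_train_loader are existing detectron2 APIs.

```python
# Illustrative sketch: register an IterableDataset in DatasetCatalog and
# build a training loader for it purely from the config.
import torch

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, build_detection_train_loader


class MyStream(torch.utils.data.IterableDataset):
    def __iter__(self):
        while True:
            yield {"value": 1}  # dummy samples; a real dataset would yield image dicts


DatasetCatalog.register("my_stream", lambda: MyStream())

cfg = get_cfg()
cfg.DATASETS.TRAIN = ["my_stream"]

# No sampler is created for an IterableDataset; aspect-ratio grouping is
# disabled here because the dummy samples carry no image sizes.
loader = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False)
next(iter(loader))
```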

File tree: 3 files changed, +49 -17 lines changed

INSTALL.md

Lines changed: 2 additions & 1 deletion
@@ -187,7 +187,8 @@ C++ compilation errors from NVCC / NVRTC, or "Unsupported gpu architecture"
 <br/>
 A few possibilities:
 
-1. Local CUDA/NVCC version has to match the CUDA version of your PyTorch. Both can be found in `python collect_env.py`.
+1. Local CUDA/NVCC version has to match the CUDA version of your PyTorch. Both can be found in `python collect_env.py`
+   (download from [here](./detectron2/utils/collect_env.py)).
    When they are inconsistent, you need to either install a different build of PyTorch (or build by yourself)
    to match your local CUDA installation, or install a different version of CUDA to match PyTorch.

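For context on the instruction above (not part of this commit or of collect_env.py): the version match can be eyeballed by comparing the CUDA version PyTorch was built with against the local nvcc, roughly as in this sketch.

```python
# Rough, illustrative check of the CUDA/NVCC vs. PyTorch match described above.
import subprocess

import torch

print("PyTorch built with CUDA:", torch.version.cuda)  # e.g. "11.3"; None for CPU-only builds
try:
    nvcc = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    print("Local NVCC:", nvcc.strip().splitlines()[-1])  # last line contains the release string
except FileNotFoundError:
    print("nvcc not found on PATH; the local CUDA toolkit may not be installed")
```
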
detectron2/data/build.py

Lines changed: 29 additions & 15 deletions
@@ -239,6 +239,15 @@ def get_detection_dataset_dicts(
         names = [names]
     assert len(names), names
     dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
+
+    if isinstance(dataset_dicts[0], torchdata.Dataset):
+        if len(dataset_dicts) > 1:
+            # ConcatDataset does not work for iterable style dataset.
+            # We could support concat for iterable as well, but it's often
+            # not a good idea to concat iterables anyway.
+            return torchdata.ConcatDataset(dataset_dicts)
+        return dataset_dicts[0]
+
     for dataset_name, dicts in zip(names, dataset_dicts):
         assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

@@ -250,9 +259,6 @@ def get_detection_dataset_dicts(
             for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
         ]
 
-    if isinstance(dataset_dicts[0], torchdata.Dataset):
-        return torchdata.ConcatDataset(dataset_dicts)
-
     dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
 
     has_instances = "annotations" in dataset_dicts[0]
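
With the two hunks above, get_detection_dataset_dicts returns torch-style datasets early instead of treating them as lists of dicts: a single registered Dataset is returned unchanged, and several are wrapped in a ConcatDataset. A small sketch of that behavior; the names "toy_a"/"toy_b" and the TensorDataset payloads are made up for illustration.

```python
# Illustrative only: the registered names and TensorDataset contents are made up.
import torch

from detectron2.data import DatasetCatalog, get_detection_dataset_dicts

DatasetCatalog.register("toy_a", lambda: torch.utils.data.TensorDataset(torch.arange(4)))
DatasetCatalog.register("toy_b", lambda: torch.utils.data.TensorDataset(torch.arange(4, 8)))

ds = get_detection_dataset_dicts(["toy_a", "toy_b"])
# With the change above, ds is a torch.utils.data.ConcatDataset of length 8;
# a single name, e.g. get_detection_dataset_dicts(["toy_a"]), returns the
# registered Dataset object itself.
print(type(ds), len(ds))
```
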
@@ -351,18 +357,24 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
     if sampler is None:
         sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
         logger = logging.getLogger(__name__)
-        logger.info("Using training sampler {}".format(sampler_name))
-        if sampler_name == "TrainingSampler":
-            sampler = TrainingSampler(len(dataset))
-        elif sampler_name == "RepeatFactorTrainingSampler":
-            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
-                dataset, cfg.DATALOADER.REPEAT_THRESHOLD
-            )
-            sampler = RepeatFactorTrainingSampler(repeat_factors)
-        elif sampler_name == "RandomSubsetTrainingSampler":
-            sampler = RandomSubsetTrainingSampler(len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO)
+        if isinstance(dataset, torchdata.IterableDataset):
+            logger.info("Not using any sampler since the dataset is IterableDataset.")
+            sampler = None
         else:
-            raise ValueError("Unknown training sampler: {}".format(sampler_name))
+            logger.info("Using training sampler {}".format(sampler_name))
+            if sampler_name == "TrainingSampler":
+                sampler = TrainingSampler(len(dataset))
+            elif sampler_name == "RepeatFactorTrainingSampler":
+                repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
+                )
+                sampler = RepeatFactorTrainingSampler(repeat_factors)
+            elif sampler_name == "RandomSubsetTrainingSampler":
+                sampler = RandomSubsetTrainingSampler(
+                    len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
+                )
+            else:
+                raise ValueError("Unknown training sampler: {}".format(sampler_name))
 
     return {
         "dataset": dataset,
@@ -461,7 +473,9 @@ def _test_loader_from_config(cfg, dataset_name, mapper=None):
         "dataset": dataset,
         "mapper": mapper,
         "num_workers": cfg.DATALOADER.NUM_WORKERS,
-        "sampler": InferenceSampler(len(dataset)),
+        "sampler": InferenceSampler(len(dataset))
+        if not isinstance(dataset, torchdata.IterableDataset)
+        else None,
     }
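Both loader-config hunks above set the sampler to None when the dataset is an IterableDataset. That follows how PyTorch's DataLoader works: iterable-style datasets define their own iteration order and reject explicit samplers. A minimal illustration of that constraint (plain PyTorch, not detectron2 code):

```python
# Minimal illustration of why no sampler is built for IterableDataset.
import torch
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class Stream(IterableDataset):
    def __iter__(self):
        yield from range(4)


# Fine: the iterable dataset controls its own iteration order.
print(list(DataLoader(Stream(), batch_size=2)))

# Raises ValueError: DataLoader forbids a sampler for iterable-style datasets.
try:
    DataLoader(Stream(), sampler=SequentialSampler(range(4)))
except ValueError as e:
    print("Rejected as expected:", e)
```
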
tests/data/test_dataset.py

Lines changed: 18 additions & 1 deletion
@@ -9,8 +9,9 @@
 from iopath.common.file_io import LazyPath
 
 from detectron2 import model_zoo
-from detectron2.config import instantiate
+from detectron2.config import get_cfg, instantiate
 from detectron2.data import (
+    DatasetCatalog,
     DatasetFromList,
     MapDataset,
     ToIterableDataset,

@@ -112,6 +113,22 @@ def test_build_iterable_dataloader_train(self):
         dl = build_detection_train_loader(dataset=ds, **kwargs)
         next(iter(dl))
 
+    def test_build_iterable_dataloader_from_cfg(self):
+        cfg = get_cfg()
+
+        class MyData(torch.utils.data.IterableDataset):
+            def __iter__(self):
+                while True:
+                    yield 1
+
+        cfg.DATASETS.TRAIN = ["iter_data"]
+        DatasetCatalog.register("iter_data", lambda: MyData())
+        dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False)
+        next(iter(dl))
+
+        dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x)
+        next(iter(dl))
+
     def _check_is_range(self, data_loader, N):
         # check that data_loader produces range(N)
         data = list(iter(data_loader))
