@@ -239,6 +239,15 @@ def get_detection_dataset_dicts(
239239 names = [names ]
240240 assert len (names ), names
241241 dataset_dicts = [DatasetCatalog .get (dataset_name ) for dataset_name in names ]
242+
243+ if isinstance (dataset_dicts [0 ], torchdata .Dataset ):
244+ if len (dataset_dicts ) > 1 :
245+ # ConcatDataset does not work for iterable style dataset.
246+ # We could support concat for iterable as well, but it's often
247+ # not a good idea to concat iterables anyway.
248+ return torchdata .ConcatDataset (dataset_dicts )
249+ return dataset_dicts [0 ]
250+
242251 for dataset_name , dicts in zip (names , dataset_dicts ):
243252 assert len (dicts ), "Dataset '{}' is empty!" .format (dataset_name )
244253
@@ -250,9 +259,6 @@ def get_detection_dataset_dicts(
250259 for dataset_i_dicts , proposal_file in zip (dataset_dicts , proposal_files )
251260 ]
252261
253- if isinstance (dataset_dicts [0 ], torchdata .Dataset ):
254- return torchdata .ConcatDataset (dataset_dicts )
255-
256262 dataset_dicts = list (itertools .chain .from_iterable (dataset_dicts ))
257263
258264 has_instances = "annotations" in dataset_dicts [0 ]
@@ -351,18 +357,24 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
351357 if sampler is None :
352358 sampler_name = cfg .DATALOADER .SAMPLER_TRAIN
353359 logger = logging .getLogger (__name__ )
354- logger .info ("Using training sampler {}" .format (sampler_name ))
355- if sampler_name == "TrainingSampler" :
356- sampler = TrainingSampler (len (dataset ))
357- elif sampler_name == "RepeatFactorTrainingSampler" :
358- repeat_factors = RepeatFactorTrainingSampler .repeat_factors_from_category_frequency (
359- dataset , cfg .DATALOADER .REPEAT_THRESHOLD
360- )
361- sampler = RepeatFactorTrainingSampler (repeat_factors )
362- elif sampler_name == "RandomSubsetTrainingSampler" :
363- sampler = RandomSubsetTrainingSampler (len (dataset ), cfg .DATALOADER .RANDOM_SUBSET_RATIO )
360+ if isinstance (dataset , torchdata .IterableDataset ):
361+ logger .info ("Not using any sampler since the dataset is IterableDataset." )
362+ sampler = None
364363 else :
365- raise ValueError ("Unknown training sampler: {}" .format (sampler_name ))
364+ logger .info ("Using training sampler {}" .format (sampler_name ))
365+ if sampler_name == "TrainingSampler" :
366+ sampler = TrainingSampler (len (dataset ))
367+ elif sampler_name == "RepeatFactorTrainingSampler" :
368+ repeat_factors = RepeatFactorTrainingSampler .repeat_factors_from_category_frequency (
369+ dataset , cfg .DATALOADER .REPEAT_THRESHOLD
370+ )
371+ sampler = RepeatFactorTrainingSampler (repeat_factors )
372+ elif sampler_name == "RandomSubsetTrainingSampler" :
373+ sampler = RandomSubsetTrainingSampler (
374+ len (dataset ), cfg .DATALOADER .RANDOM_SUBSET_RATIO
375+ )
376+ else :
377+ raise ValueError ("Unknown training sampler: {}" .format (sampler_name ))
366378
367379 return {
368380 "dataset" : dataset ,
@@ -461,7 +473,9 @@ def _test_loader_from_config(cfg, dataset_name, mapper=None):
461473 "dataset" : dataset ,
462474 "mapper" : mapper ,
463475 "num_workers" : cfg .DATALOADER .NUM_WORKERS ,
464- "sampler" : InferenceSampler (len (dataset )),
476+ "sampler" : InferenceSampler (len (dataset ))
477+ if not isinstance (dataset , torchdata .IterableDataset )
478+ else None ,
465479 }
466480
467481
0 commit comments