@@ -242,6 +242,9 @@ def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, propo
             for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
         ]

+    if isinstance(dataset_dicts[0], torchdata.Dataset):
+        return torchdata.ConcatDataset(dataset_dicts)
+
     dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

     has_instances = "annotations" in dataset_dicts[0]
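With this early return, entries that are already torch `Dataset` objects are concatenated and returned as-is, skipping the dict-based filtering that follows. A minimal sketch of the resulting behavior; `TrivialDataset` is a hypothetical stand-in, not a detectron2 class:

```python
import torch.utils.data as torchdata

class TrivialDataset(torchdata.Dataset):
    """Hypothetical map-style dataset, used only to illustrate the new path."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

# Mirrors the patched branch: Dataset inputs are concatenated and returned
# without the annotation filtering applied to lists of dicts.
parts = [TrivialDataset(range(3)), TrivialDataset(range(5))]
combined = torchdata.ConcatDataset(parts)
assert len(combined) == 8  # 3 + 5 samples, in order
```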
@@ -263,7 +266,13 @@ def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, propo


 def build_batch_data_loader(
-    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
+    dataset,
+    sampler,
+    total_batch_size,
+    *,
+    aspect_ratio_grouping=False,
+    num_workers=0,
+    collate_fn=None,
 ):
     """
     Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
@@ -274,7 +283,7 @@ def build_batch_data_loader(
         dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
         sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
             Must be provided iff. ``dataset`` is a map-style dataset.
-        total_batch_size, aspect_ratio_grouping, num_workers): see
+        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
             :func:`build_detection_train_loader`.

     Returns:
@@ -301,14 +310,17 @@ def build_batch_data_loader(
             collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
             worker_init_fn=worker_init_reset_seed,
         )  # yield individual mapped dict
-        return AspectRatioGroupedDataset(data_loader, batch_size)
+        data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
+        if collate_fn is None:
+            return data_loader
+        return MapDataset(data_loader, collate_fn)
     else:
         return torchdata.DataLoader(
             dataset,
             batch_size=batch_size,
             drop_last=True,
             num_workers=num_workers,
-            collate_fn=trivial_batch_collator,
+            collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
             worker_init_fn=worker_init_reset_seed,
         )

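In both branches, `collate_fn` receives a `list[dict]` of mapped samples: with aspect-ratio grouping it is applied per batch through `MapDataset`, and otherwise it replaces `trivial_batch_collator`, which returns the batch list unchanged. A hedged sketch of a compatible collate function; the `"image"` key and equal image shapes within a batch are assumptions for illustration, not a required schema:

```python
import torch

def stack_images_collate(batch):
    """Hypothetical collate_fn: stack image tensors, keep other fields as lists."""
    # batch is a list[dict], one entry per sample; stacking assumes all
    # images in the batch share a shape (not guaranteed by grouping alone).
    out = {"image": torch.stack([d["image"] for d in batch])}
    for key in batch[0]:
        if key != "image":
            out[key] = [d[key] for d in batch]
    return out
```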
@@ -356,7 +368,14 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):

 @configurable(from_config=_train_loader_from_config)
 def build_detection_train_loader(
-    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
+    dataset,
+    *,
+    mapper,
+    sampler=None,
+    total_batch_size,
+    aspect_ratio_grouping=True,
+    num_workers=0,
+    collate_fn=None,
 ):
     """
     Build a dataloader for object detection with some default features.
@@ -380,6 +399,11 @@ def build_detection_train_loader(
             aspect ratio for efficiency. When enabled, it requires each
             element in dataset be a dict with keys "width" and "height".
         num_workers (int): number of parallel data loading workers
+        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
+            Defaults to doing no collation and returning a list of data.
+            No collation is OK for small batch sizes and simple data structures.
+            If your batch size is large and each sample contains many small tensors,
+            it's more efficient to collate them in the data loader.

     Returns:
         torch.utils.data.DataLoader:
@@ -404,6 +428,7 @@ def build_detection_train_loader(
         total_batch_size,
         aspect_ratio_grouping=aspect_ratio_grouping,
         num_workers=num_workers,
+        collate_fn=collate_fn,
     )


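With the argument now forwarded to `build_batch_data_loader`, callers can opt into custom collation end to end. A hedged usage sketch with a tiny synthetic dataset and an identity mapper, both illustrative assumptions; real training would use a `DatasetMapper` and registered data:

```python
from detectron2.data import build_detection_train_loader

# Tiny synthetic dataset of dicts; "width"/"height" would be required
# if aspect_ratio_grouping were enabled.
data = [{"image_id": i, "width": 10, "height": 20} for i in range(8)]

loader = build_detection_train_loader(
    data,
    mapper=lambda d: d,           # identity mapper, for illustration only
    total_batch_size=4,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,              # default: each batch is a list[dict]
)
batch = next(iter(loader))        # the training sampler is infinite
assert isinstance(batch, list) and len(batch) == 4
```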
@@ -430,7 +455,7 @@ def _test_loader_from_config(cfg, dataset_name, mapper=None):


 @configurable(from_config=_test_loader_from_config)
-def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0):
+def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0, collate_fn=None):
     """
     Similar to `build_detection_train_loader`, but uses a batch size of 1,
     and :class:`InferenceSampler`. This sampler coordinates all workers to
@@ -449,6 +474,8 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
             which splits the dataset across all workers. Sampler must be None
             if `dataset` is iterable.
         num_workers (int): number of parallel data loading workers
+        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
+            Defaults to doing no collation and returning a list of data.

     Returns:
         DataLoader: a torch DataLoader, that loads the given detection
@@ -479,7 +506,7 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
         batch_size=1,
         sampler=sampler,
         num_workers=num_workers,
-        collate_fn=trivial_batch_collator,
+        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
     )


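The test loader gets the same override. Since its batch size is fixed at 1, a custom `collate_fn` always receives a single-element list, which makes unwrapping the batch a natural use. A hedged sketch; the synthetic data and identity mapper are illustrative assumptions:

```python
from detectron2.data import build_detection_test_loader

data = [{"image_id": i} for i in range(4)]
loader = build_detection_test_loader(
    data,
    mapper=lambda d: d,                 # identity mapper, for illustration
    num_workers=0,
    collate_fn=lambda batch: batch[0],  # unwrap the singleton batch
)
for sample in loader:
    assert isinstance(sample, dict)     # dicts, not 1-element lists
```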