Skip to content

Commit 0316cb7

Browse files
ppwwyyxx authored and facebook-github-bot committed
support registering a dataset object; support customizing collate_fn
Reviewed By: zhanghang1989 Differential Revision: D31201334 fbshipit-source-id: 38928aa0eec0749af1edd10f40383257b29bec3c
1 parent 0a2a4a3 commit 0316cb7

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

detectron2/data/build.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, propo
242242
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
243243
]
244244

245+
if isinstance(dataset_dicts[0], torchdata.Dataset):
246+
return torchdata.ConcatDataset(dataset_dicts)
247+
245248
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
246249

247250
has_instances = "annotations" in dataset_dicts[0]
@@ -263,7 +266,13 @@ def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, propo
263266

264267

265268
def build_batch_data_loader(
266-
dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
269+
dataset,
270+
sampler,
271+
total_batch_size,
272+
*,
273+
aspect_ratio_grouping=False,
274+
num_workers=0,
275+
collate_fn=None,
267276
):
268277
"""
269278
Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
@@ -274,7 +283,7 @@ def build_batch_data_loader(
274283
dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
275284
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
276285
Must be provided iff. ``dataset`` is a map-style dataset.
277-
total_batch_size, aspect_ratio_grouping, num_workers): see
286+
total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
278287
:func:`build_detection_train_loader`.
279288
280289
Returns:
@@ -301,14 +310,17 @@ def build_batch_data_loader(
301310
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
302311
worker_init_fn=worker_init_reset_seed,
303312
) # yield individual mapped dict
304-
return AspectRatioGroupedDataset(data_loader, batch_size)
313+
data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
314+
if collate_fn is None:
315+
return data_loader
316+
return MapDataset(data_loader, collate_fn)
305317
else:
306318
return torchdata.DataLoader(
307319
dataset,
308320
batch_size=batch_size,
309321
drop_last=True,
310322
num_workers=num_workers,
311-
collate_fn=trivial_batch_collator,
323+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
312324
worker_init_fn=worker_init_reset_seed,
313325
)
314326

@@ -356,7 +368,14 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
356368

357369
@configurable(from_config=_train_loader_from_config)
358370
def build_detection_train_loader(
359-
dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
371+
dataset,
372+
*,
373+
mapper,
374+
sampler=None,
375+
total_batch_size,
376+
aspect_ratio_grouping=True,
377+
num_workers=0,
378+
collate_fn=None,
360379
):
361380
"""
362381
Build a dataloader for object detection with some default features.
@@ -380,6 +399,11 @@ def build_detection_train_loader(
380399
aspect ratio for efficiency. When enabled, it requires each
381400
element in dataset be a dict with keys "width" and "height".
382401
num_workers (int): number of parallel data loading workers
402+
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
403+
Defaults to do no collation and return a list of data.
404+
No collation is OK for small batch size and simple data structures.
405+
If your batch size is large and each sample contains too many small tensors,
406+
it's more efficient to collate them in data loader.
383407
384408
Returns:
385409
torch.utils.data.DataLoader:
@@ -404,6 +428,7 @@ def build_detection_train_loader(
404428
total_batch_size,
405429
aspect_ratio_grouping=aspect_ratio_grouping,
406430
num_workers=num_workers,
431+
collate_fn=collate_fn,
407432
)
408433

409434

@@ -430,7 +455,7 @@ def _test_loader_from_config(cfg, dataset_name, mapper=None):
430455

431456

432457
@configurable(from_config=_test_loader_from_config)
433-
def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0):
458+
def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0, collate_fn=None):
434459
"""
435460
Similar to `build_detection_train_loader`, but uses a batch size of 1,
436461
and :class:`InferenceSampler`. This sampler coordinates all workers to
@@ -449,6 +474,8 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
449474
which splits the dataset across all workers. Sampler must be None
450475
if `dataset` is iterable.
451476
num_workers (int): number of parallel data loading workers
477+
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
478+
Defaults to do no collation and return a list of data.
452479
453480
Returns:
454481
DataLoader: a torch DataLoader, that loads the given detection
@@ -479,7 +506,7 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
479506
batch_size=1,
480507
sampler=sampler,
481508
num_workers=num_workers,
482-
collate_fn=trivial_batch_collator,
509+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
483510
)
484511

485512

0 commit comments

Comments (0)