diff --git a/megatron/arguments.py b/megatron/arguments.py
index 2be64b77d..e1a973b05 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -897,6 +897,8 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
                        help="Dataloader number of workers.")
+    group.add_argument('--valid-num-workers', type=int, default=2,
+                       help="Dataloader number of workers for validation.")
     group.add_argument('--tokenizer-type', type=str,
                        default=None,
                        choices=['BertWordPieceLowerCase',
diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index 1cbeac312..c8109b3d2 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -22,7 +22,7 @@
 from megatron import mpu


-def build_pretraining_data_loader(dataset, consumed_samples):
+def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
     """Buld dataloader given an input dataset."""

     if dataset is None:
@@ -48,10 +48,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         raise Exception('{} dataloader type is not supported.'.format(
                 args.dataloader_type))

+    if num_workers is None:
+        num_workers = args.num_workers
+
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
-                                       num_workers=args.num_workers,
+                                       num_workers=num_workers,
                                        pin_memory=True)

 class MegatronPretrainingSampler:
@@ -141,7 +144,7 @@ def __iter__(self):
                        * self.micro_batch_size
         bucket_offset = current_epoch_samples // self.data_parallel_size
         start_idx = self.data_parallel_rank * bucket_size
-
+
         g = torch.Generator()
         g.manual_seed(self.epoch)
         random_idx = torch.randperm(bucket_size, generator=g).tolist()
diff --git a/megatron/training.py b/megatron/training.py
index 84fd4eb9d..da8259450 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1132,7 +1132,12 @@ def build_train_valid_test_data_iterators(

         # We collapse None and empty list as both should mean we don't run validation
         # args.consumed_valid_samples accumulates the sum of valid steps for every dataset, which are all equal
-        valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds))
+        #
+        # XXX: we get a deadlock in the dataloader on multi-dataset eval, after the first dataset,
+        # possibly due to this bug in pytorch https://github.com/pytorch/pytorch/pull/25158. Using
+        # num_workers=0 to work around it - the training can't use that since it impacts throughput
+        # by a few percent
+        valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds), num_workers=args.valid_num_workers)
                              for d in valid_ds] \
                             if valid_ds is not None else []

         # We collapse None and empty list as both should mean we don't run test
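
For reference, a minimal sketch of the worker-count fallback introduced in data_samplers.py and how it interacts with the new flag. This is not part of the patch; the helper name _pick_num_workers is hypothetical and only mirrors the "explicit override wins, otherwise use args.num_workers" logic above.

    # Sketch of the override added to build_pretraining_data_loader:
    # an explicit num_workers argument wins, otherwise fall back to args.num_workers.
    def _pick_num_workers(num_workers, args_num_workers):
        return args_num_workers if num_workers is None else num_workers

    # Training dataloader: no override, keeps --num-workers (default 2).
    assert _pick_num_workers(None, 2) == 2
    # Validation dataloaders launched with --valid-num-workers 0, which avoids the
    # multi-dataset eval deadlock noted in the training.py comment, at the cost of
    # single-process data loading during eval only.
    assert _pick_num_workers(0, 2) == 0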