|
5 | 5 | import logging |
6 | 6 | from functools import partial |
7 | 7 |
|
8 | | -import torch |
9 | | -import torch.distributed |
10 | | - |
11 | 8 | from ..arguments import DatasetArgs, InferenceArgs, TrainingArgs |
12 | 9 | from ..enums import DatasetSplit, Mode |
13 | 10 | from ..tokenizers import TOKENIZER_TYPE |
14 | 11 | from ..utils import ProcessGroupManager, log_rank_0, run_rank_n |
15 | 12 | from .base import BaseDataset, BlendedDatasets |
16 | | -from .dataloader import DispatchingDataLoader, ResumableDataLoader |
| 13 | +from .dataloader import ResumableDataLoader |
17 | 14 | from .debug import DebugDataset |
18 | 15 | from .huggingface import HuggingFaceDataset |
19 | 16 | from .ibm import get_ibm_dataloaders |
@@ -107,111 +104,6 @@ def get_finetuning_dataloader( |
107 | 104 | if ProcessGroupManager.get_tensor_parallel_rank() != 0: |
108 | 105 | return |
109 | 106 |
|
110 | | - if args.distributed_args.dispatching_dataloader: |
111 | | - assert ( |
112 | | - ProcessGroupManager.get_tensor_parallel_world_size() == 1 |
113 | | - ), "tensor parallel doesn't support dispatching dataloader" |
114 | | - |
115 | | - dataloader = _get_dispatching_dataloader(args, split=split, mode=mode, tokenizer=tokenizer) |
116 | | - else: |
117 | | - dataloader = _get_non_dispatching_dataloader(args, split=split, mode=mode, tokenizer=tokenizer) |
118 | | - |
119 | | - return dataloader |
120 | | - |
121 | | - |
def get_pretraining_dataloaders(
    args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
    """Build the pretraining dataloaders by dispatching on the dataset class name.

    Dispatches on ``args.datasets[0].class_name``: Megatron datasets receive
    ``consumed_samples`` so training can resume mid-epoch, IBM datasets do not.

    Args:
        args: training configuration; only the first dataset config's
            ``class_name`` is inspected here.
        tokenizer: tokenizer forwarded to the dataset builders.
        consumed_samples: number of samples already consumed (used only by the
            Megatron path).

    Returns:
        Whatever the selected builder returns — per the annotation, a tuple of
        (train dataloader, val dataloaders, test dataloaders).
    """
    if args.datasets[0].class_name == "MegatronDataset":
        dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
    elif args.datasets[0].class_name == "IBMDataset":
        dataloaders = get_ibm_dataloaders(args, tokenizer)

    # NOTE(review): any other class_name leaves `dataloaders` unbound and this
    # line raises UnboundLocalError — an explicit error branch would be clearer.
    return dataloaders
132 | | - |
def _get_dispatching_dataloader(
    args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
) -> ResumableDataLoader:
    """Build a dataloader where only each node's first rank reads data.

    The first rank of every node loads and samples the dataset, then the
    ``DispatchingDataLoader`` broadcasts micro-batches to the other ranks on
    that node via per-node process groups. Returns ``None`` on the first rank
    of a node when there are no datasets for this split.

    NOTE: every rank must execute this function — ``torch.distributed.new_group``
    below is collective and would deadlock otherwise.
    """
    micro_batch_size = args.training_parameters.micro_batch_size

    # assumes one rank per visible GPU and a homogeneous GPU count per node —
    # node/rank arithmetic below relies on it (TODO confirm with launcher setup)
    num_ranks_per_node = torch.cuda.device_count()
    node_rank = ProcessGroupManager.get_global_rank() // num_ranks_per_node
    num_nodes = ProcessGroupManager.get_world_size() // num_ranks_per_node

    def _get_source_broadcast_mapping() -> dict:
        # map each node's first (source) rank -> process group of that node's
        # ranks; new_group is collective, so ALL ranks build ALL groups
        result = {}
        for i in range(num_nodes):
            source = i * num_ranks_per_node
            ranks = list(range(source, source + num_ranks_per_node))
            result[source] = torch.distributed.new_group(ranks)
        return result

    source_broadcast_mapping = _get_source_broadcast_mapping()

    # check if node's first rank
    if ProcessGroupManager.get_global_rank() == node_rank * num_ranks_per_node:
        datasets_list, data_sampling_ratios = get_datasets_list(
            dataset_args_list=args.datasets, split=split, mode=Mode.training, tokenizer=tokenizer
        )

        # nothing to load for this split on this node's source rank
        if len(datasets_list) == 0:
            return None

        blended_dataset = BlendedDatasets(datasets=datasets_list, split=split)
        # a single dataset needs no blending ratio
        data_sampling_ratios = [1] if len(datasets_list) == 1 else data_sampling_ratios

        # each node is given a data sampler
        # TODO modify this when we add model parallelism

        # sampler routes to the dispatching parent worker
        sampler = BlendedDistributedSampler(
            dataset=blended_dataset,
            data_sampling_ratios=data_sampling_ratios,
            # sharded across nodes (not ranks): each node's source rank draws
            # that node's shard and fans it out locally
            num_replicas=num_nodes,
            rank=node_rank,
            ignore_sampling_proportion_for_validation=args.training_parameters.ignore_sampling_proportion_for_validation,
            shuffle=split == DatasetSplit.train,
            seed=args.random_args.seed,
            drop_last=False,
        )
    else:
        # non-source ranks hold no data; they only receive broadcasts
        blended_dataset = None
        data_sampling_ratios = None
        sampler = None

    # dataloader does local dispatching and thus needs source_rank and broadcast_ranks
    dataloader = DispatchingDataLoader(
        blended_dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        collate_fn=partial(
            collate_fn,
            mode=mode,
            loss_mask=args.training_parameters.loss_mask,
            eos_token_id=tokenizer.eos_token_id,
            use_padding_free_transformer=args.model_args.use_padding_free_transformer,
            pad_to_multiple_of=ProcessGroupManager.get_tensor_parallel_world_size(),
        ),
        source_broadcast_mapping=source_broadcast_mapping,
        broadcast_world_size=num_ranks_per_node,
    )

    # rank-0-only logging of dataset composition (see @run_rank_n on _log_dataset)
    _log_dataset(
        blended_dataset=blended_dataset,
        sampler=sampler,
        split=split,
        num_training_steps=args.training_parameters.num_training_steps,
        gradient_accumulation_steps=args.training_parameters.gradient_accumulation_steps,
        micro_batch_size=args.training_parameters.micro_batch_size,
    )

    return dataloader
210 | | - |
211 | | - |
212 | | -def _get_non_dispatching_dataloader( |
213 | | - args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE |
214 | | -) -> ResumableDataLoader: |
215 | 107 | micro_batch_size = args.training_parameters.micro_batch_size |
216 | 108 |
|
217 | 109 | datasets_list, data_sampling_ratios = get_datasets_list( |
@@ -262,6 +154,17 @@ def _get_non_dispatching_dataloader( |
262 | 154 | return dataloader |
263 | 155 |
|
264 | 156 |
|
def get_pretraining_dataloaders(
    args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
    """Build the pretraining dataloaders by dispatching on the dataset class name.

    Dispatches on ``args.datasets[0].class_name``: Megatron datasets receive
    ``consumed_samples`` so training can resume mid-epoch, IBM datasets do not.

    Args:
        args: training configuration; only the first dataset config's
            ``class_name`` is inspected here.
        tokenizer: tokenizer forwarded to the dataset builders.
        consumed_samples: number of samples already consumed (used only by the
            Megatron path).

    Returns:
        Whatever the selected builder returns — per the annotation, a tuple of
        (train dataloader, val dataloaders, test dataloaders).

    Raises:
        ValueError: if the dataset class name is not one of the supported
            pretraining dataset types.
    """
    dataset_class_name = args.datasets[0].class_name

    if dataset_class_name == "MegatronDataset":
        dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
    elif dataset_class_name == "IBMDataset":
        dataloaders = get_ibm_dataloaders(args, tokenizer)
    else:
        # previously an unsupported class name fell through and crashed with an
        # opaque UnboundLocalError on the return; fail loudly and clearly instead
        raise ValueError(f"unexpected dataset class_name ({dataset_class_name}) for pretraining")

    return dataloaders
| 167 | + |
265 | 168 | @run_rank_n |
266 | 169 | def _log_dataset( |
267 | 170 | blended_dataset: BlendedDatasets, |
|
0 commit comments