
Commit 6cc864b

drop dispatching dataloader (#245)
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
1 parent 99d3937 · commit 6cc864b

5 files changed (+32, -316 lines)

configs/finetuning-example.yml

Lines changed: 0 additions & 2 deletions
```diff
@@ -74,5 +74,3 @@ distributed_args:
   # use ZeRO-3 for model sharding, saves most memory but needs more communication. this is fine since we are doing training on 2 GPUs and they are connected via NVLink
   stage: 3
   torch_compile: true
-  # this will load dataset only on the first GPU and send part of the data to the other GPUs, not recommended unless the datasets are immensely large
-  dispatching_dataloader: false
```

lm_engine/arguments.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -263,8 +263,6 @@ class DistributedArgs(BaseArgs):
     communication_dtype: str | None = None
     # whether to use torch.compile
     torch_compile: bool = False
-    # whether to use a dispatching dataloader
-    dispatching_dataloader: bool = False
     # tensor parallel world size
     tensor_parallel_world_size: int = 1
     # whether to use sequence parallel
```

lm_engine/data/__init__.py

Lines changed: 12 additions & 109 deletions
```diff
@@ -5,15 +5,12 @@
 import logging
 from functools import partial

-import torch
-import torch.distributed
-
 from ..arguments import DatasetArgs, InferenceArgs, TrainingArgs
 from ..enums import DatasetSplit, Mode
 from ..tokenizers import TOKENIZER_TYPE
 from ..utils import ProcessGroupManager, log_rank_0, run_rank_n
 from .base import BaseDataset, BlendedDatasets
-from .dataloader import DispatchingDataLoader, ResumableDataLoader
+from .dataloader import ResumableDataLoader
 from .debug import DebugDataset
 from .huggingface import HuggingFaceDataset
 from .ibm import get_ibm_dataloaders
@@ -107,111 +104,6 @@ def get_finetuning_dataloader(
     if ProcessGroupManager.get_tensor_parallel_rank() != 0:
         return

-    if args.distributed_args.dispatching_dataloader:
-        assert (
-            ProcessGroupManager.get_tensor_parallel_world_size() == 1
-        ), "tensor parallel doesn't support dispatching dataloader"
-
-        dataloader = _get_dispatching_dataloader(args, split=split, mode=mode, tokenizer=tokenizer)
-    else:
-        dataloader = _get_non_dispatching_dataloader(args, split=split, mode=mode, tokenizer=tokenizer)
-
-    return dataloader
-
-
-def get_pretraining_dataloaders(
-    args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
-) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
-    if args.datasets[0].class_name == "MegatronDataset":
-        dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
-    elif args.datasets[0].class_name == "IBMDataset":
-        dataloaders = get_ibm_dataloaders(args, tokenizer)
-
-    return dataloaders
-
-
-def _get_dispatching_dataloader(
-    args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
-) -> ResumableDataLoader:
-    micro_batch_size = args.training_parameters.micro_batch_size
-
-    num_ranks_per_node = torch.cuda.device_count()
-    node_rank = ProcessGroupManager.get_global_rank() // num_ranks_per_node
-    num_nodes = ProcessGroupManager.get_world_size() // num_ranks_per_node
-
-    def _get_source_broadcast_mapping() -> dict:
-        result = {}
-        for i in range(num_nodes):
-            source = i * num_ranks_per_node
-            ranks = list(range(source, source + num_ranks_per_node))
-            result[source] = torch.distributed.new_group(ranks)
-        return result
-
-    source_broadcast_mapping = _get_source_broadcast_mapping()
-
-    # check if node's first rank
-    if ProcessGroupManager.get_global_rank() == node_rank * num_ranks_per_node:
-        datasets_list, data_sampling_ratios = get_datasets_list(
-            dataset_args_list=args.datasets, split=split, mode=Mode.training, tokenizer=tokenizer
-        )
-
-        if len(datasets_list) == 0:
-            return None
-
-        blended_dataset = BlendedDatasets(datasets=datasets_list, split=split)
-        data_sampling_ratios = [1] if len(datasets_list) == 1 else data_sampling_ratios
-
-        # each node is given a data sampler
-        # TODO modify this when we add model parallelism
-
-        # sampler routes to the dispatching parent worker
-        sampler = BlendedDistributedSampler(
-            dataset=blended_dataset,
-            data_sampling_ratios=data_sampling_ratios,
-            num_replicas=num_nodes,
-            rank=node_rank,
-            ignore_sampling_proportion_for_validation=args.training_parameters.ignore_sampling_proportion_for_validation,
-            shuffle=split == DatasetSplit.train,
-            seed=args.random_args.seed,
-            drop_last=False,
-        )
-    else:
-        blended_dataset = None
-        data_sampling_ratios = None
-        sampler = None
-
-    # dataloader does local dispatching and thus needs source_rank and broadcast_ranks
-    dataloader = DispatchingDataLoader(
-        blended_dataset,
-        batch_size=micro_batch_size,
-        sampler=sampler,
-        collate_fn=partial(
-            collate_fn,
-            mode=mode,
-            loss_mask=args.training_parameters.loss_mask,
-            eos_token_id=tokenizer.eos_token_id,
-            use_padding_free_transformer=args.model_args.use_padding_free_transformer,
-            pad_to_multiple_of=ProcessGroupManager.get_tensor_parallel_world_size(),
-        ),
-        source_broadcast_mapping=source_broadcast_mapping,
-        broadcast_world_size=num_ranks_per_node,
-    )
-
-    _log_dataset(
-        blended_dataset=blended_dataset,
-        sampler=sampler,
-        split=split,
-        num_training_steps=args.training_parameters.num_training_steps,
-        gradient_accumulation_steps=args.training_parameters.gradient_accumulation_steps,
-        micro_batch_size=args.training_parameters.micro_batch_size,
-    )
-
-    return dataloader
-
-
-def _get_non_dispatching_dataloader(
-    args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
-) -> ResumableDataLoader:
     micro_batch_size = args.training_parameters.micro_batch_size

     datasets_list, data_sampling_ratios = get_datasets_list(
@@ -262,6 +154,17 @@ def _get_non_dispatching_dataloader(
     return dataloader


+def get_pretraining_dataloaders(
+    args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
+) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
+    if args.datasets[0].class_name == "MegatronDataset":
+        dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
+    elif args.datasets[0].class_name == "IBMDataset":
+        dataloaders = get_ibm_dataloaders(args, tokenizer)
+
+    return dataloaders
+
+
 @run_rank_n
 def _log_dataset(
     blended_dataset: BlendedDatasets,
```
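For context, the removed `_get_dispatching_dataloader` paired each node's first rank (the "source" that actually loads data) with a broadcast group spanning that node's ranks. A minimal standalone sketch of that rank arithmetic, with made-up world sizes and plain rank lists in place of `torch.distributed` process groups:

```python
# Sketch of the node -> broadcast-group layout built by the removed
# _get_source_broadcast_mapping helper. Returns rank lists instead of
# torch.distributed process groups; world_size and ranks_per_node are
# illustrative values, not values from the repo.

def source_broadcast_mapping(world_size: int, ranks_per_node: int) -> dict[int, list[int]]:
    num_nodes = world_size // ranks_per_node
    mapping: dict[int, list[int]] = {}
    for node in range(num_nodes):
        source = node * ranks_per_node  # first rank on the node loads the data
        mapping[source] = list(range(source, source + ranks_per_node))
    return mapping

# 2 nodes x 4 GPUs: rank 0 dispatches to ranks 0-3, rank 4 to ranks 4-7
print(source_broadcast_mapping(world_size=8, ranks_per_node=4))
# -> {0: [0, 1, 2, 3], 4: [4, 5, 6, 7]}
```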

lm_engine/data/dataloader.py

Lines changed: 1 addition & 112 deletions
```diff
@@ -4,15 +4,7 @@

 from __future__ import annotations

-from typing import Callable, Iterable, Iterator
-
-import torch
-import torch.distributed
-from torch.distributed import ProcessGroup
-from torch.utils.data import DataLoader, Dataset, Sampler
-
-from ..communication import Communication
-from ..utils import ProcessGroupManager
+from torch.utils.data import DataLoader


 class ResumableDataLoader(DataLoader):
@@ -22,106 +14,3 @@ def state_dict(self) -> dict:
     def load_state_dict(self, state_dict: dict) -> None:
         self.dataset.load_state_dict(state_dict.get("dataset"))
         self.sampler.load_state_dict(state_dict.get("sampler"))
-
-
-class DispatchingDataLoader(ResumableDataLoader):
-    def __init__(
-        self,
-        dataset: Dataset,
-        batch_size: int | None = 1,
-        sampler: Sampler | Iterable | None = None,
-        batch_sampler: Sampler[list] | Iterable[list] | None = None,
-        num_workers: int = 0,
-        collate_fn: Callable | None = None,
-        pin_memory: bool = False,
-        drop_last: bool = False,
-        source_broadcast_mapping: dict[int, ProcessGroup] | None = None,
-        broadcast_world_size: int | None = None,
-        static_shape_per_rank: tuple[int, int] | None = None,
-        keys: list[str] = ["input_ids", "attention_mask", "labels"],
-    ) -> DispatchingDataLoader:
-        self.broadcast_world_size = broadcast_world_size
-
-        self.is_source, self.source_rank, self.local_rank_in_broadcast_group, self.broadcast_group = (
-            get_source_and_broadcast_group(source_broadcast_mapping)
-        )
-
-        super().__init__(
-            dataset=dataset,
-            batch_size=batch_size * self.broadcast_world_size if batch_sampler is None else 1,
-            sampler=sampler,
-            batch_sampler=batch_sampler,
-            num_workers=num_workers,
-            collate_fn=collate_fn,
-            pin_memory=pin_memory,
-            drop_last=drop_last,
-        )
-
-        _length = torch.tensor(
-            [super().__len__() if self.is_source else 0], dtype=torch.long, device=torch.cuda.current_device()
-        )
-        torch.distributed.broadcast(_length, src=self.source_rank, group=self.broadcast_group)
-        self._length = _length.item()
-
-        self.global_static_shape = None
-        if static_shape_per_rank is not None:
-            self.global_static_shape = (static_shape_per_rank[0] * self.broadcast_world_size, static_shape_per_rank[1])
-
-        self.keys = keys
-
-    def __iter__(self) -> Iterator[dict]:
-        iterator = super().__iter__() if self.is_source else range(self._length)
-
-        for batch in iterator:
-            # if using dynamic shapes at every batch or when batch buffer is None during static batch, we need to get shape
-            # send/recv tensor shapes
-            if self.global_static_shape is None:
-                batch_shape = batch[self.keys[0]].shape if self.is_source else None
-                batch_shape = Communication.broadcast_object(
-                    batch_shape, src=self.source_rank, group=self.broadcast_group
-                )
-            else:
-                batch_shape = self.global_static_shape
-
-            if self.is_source:
-                for key in self.keys:
-                    batch[key] = batch[key].to(torch.cuda.current_device())
-            else:
-                batch = {
-                    key: torch.empty(batch_shape, dtype=torch.long, device=torch.cuda.current_device())
-                    for key in self.keys
-                }
-
-            for key in self.keys:
-                # send/recv batch
-                torch.distributed.broadcast(batch[key], src=self.source_rank, group=self.broadcast_group)
-
-                # slice batch
-                local_batch_size = batch[key].shape[0] // self.broadcast_world_size
-                batch[key] = batch[key][
-                    self.local_rank_in_broadcast_group
-                    * local_batch_size : (self.local_rank_in_broadcast_group + 1)
-                    * local_batch_size
-                ]
-
-            yield batch
-
-    def __len__(self) -> int:
-        return self._length
-
-
-def get_source_and_broadcast_group(
-    source_broadcast_mapping: dict[int, ProcessGroup],
-) -> tuple[bool, int, int, ProcessGroup]:
-    global_rank = ProcessGroupManager.get_global_rank()
-
-    for source_rank, broadcast_group in source_broadcast_mapping.items():
-        ranks = torch.distributed.get_process_group_ranks(broadcast_group)
-
-        if global_rank in ranks:
-            is_source = global_rank == source_rank
-            local_rank_in_broadcast_group = ranks.index(global_rank)
-
-            return is_source, source_rank, local_rank_in_broadcast_group, broadcast_group
-
-    assert False, "code shouldn't reach here"
```
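In the removed `__iter__`, the source rank broadcasts a node-sized batch and every rank then keeps only its own contiguous slice of the leading dimension. The slicing arithmetic in isolation (plain tensors and assumed sizes, no process groups or broadcasts):

```python
import torch

# Illustrative values only: 4 GPUs per node and a micro batch of 2 per rank,
# so the source rank materializes a node-level batch of 8 sequences of length 16.
broadcast_world_size = 4
batch = torch.arange(8 * 16).reshape(8, 16)  # stand-in for input_ids after the broadcast

local_batch_size = batch.shape[0] // broadcast_world_size
for local_rank_in_broadcast_group in range(broadcast_world_size):
    shard = batch[
        local_rank_in_broadcast_group * local_batch_size :
        (local_rank_in_broadcast_group + 1) * local_batch_size
    ]
    # rank 0 keeps rows 0-1, rank 1 rows 2-3, rank 2 rows 4-5, rank 3 rows 6-7
    print(local_rank_in_broadcast_group, shard.shape)  # -> torch.Size([2, 16])
```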
