Skip to content

Commit a259651

Browse files
authored
[zero] support extra dp (#6123)
* [zero] support extra dp * [zero] update checkpoint * fix bugs * fix bugs
1 parent 30a9443 commit a259651

File tree

8 files changed

+238
-57
lines changed

8 files changed

+238
-57
lines changed

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
save_state_dict,
3030
sharded_optimizer_loading_epilogue,
3131
)
32+
from colossalai.cluster import ProcessGroupMesh
3233
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
3334
from colossalai.interface.optimizer import DistributedOptim
3435
from colossalai.logging import get_dist_logger
@@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
333334
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
334335
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
335336
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
337+
extra_dp_size (int, optional): The size of the extra data parallel group (the outer axis of the dp process-group mesh); must evenly divide the world size. Defaults to 1.
336338
"""
337339

338340
def __init__(
@@ -358,11 +360,16 @@ def __init__(
358360
cast_inputs: bool = True,
359361
fp8_communication: bool = False,
360362
use_fp8: bool = False,
363+
extra_dp_size: int = 1,
361364
) -> None:
362365
super().__init__()
363366
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
364367
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
365368
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
369+
if extra_dp_size > 1:
370+
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
371+
inner_dp_size = dist.get_world_size() // extra_dp_size
372+
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
366373
self.stage = stage
367374
self.precision = precision
368375
self.zero_optim_kwargs = dict(
@@ -383,6 +390,9 @@ def __init__(
383390
overlap_allgather=overlap_allgather,
384391
fp8_communication=fp8_communication,
385392
)
393+
if extra_dp_size > 1:
394+
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
395+
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
386396
self.lora_enabled = False
387397
self.verbose = verbose
388398
self.logger = get_dist_logger()

colossalai/zero/low_level/_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, Tuple, Union
33

4+
import numpy as np
45
import torch
56
import torch.distributed as dist
67
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
209210
# update the tensor data
210211
for p, q in zip(tensor_list, updated_params):
211212
p.data = q.data
213+
214+
215+
def all_gather_into_flat_tensor_nd(
216+
output_tensor: torch.Tensor,
217+
input_tensor: torch.Tensor,
218+
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
219+
async_op: bool = False,
220+
):
221+
if isinstance(group, dist.ProcessGroup):
222+
group = (group,)
223+
sizes = [dist.get_world_size(pg) for pg in group]
224+
ranks = [dist.get_rank(pg) for pg in group]
225+
for i, pg in list(enumerate(group))[::-1]:
226+
if i == 0:
227+
out = output_tensor
228+
else:
229+
prev_sizes = sizes[:i]
230+
prev_ranks = ranks[:i]
231+
chunks = output_tensor.chunk(np.prod(prev_sizes))
232+
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
233+
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
234+
input_tensor = out
235+
return handle
236+
237+
238+
def get_nd_world_size(group) -> int:
239+
if isinstance(group, tuple):
240+
return int(np.prod([dist.get_world_size(pg) for pg in group]))
241+
else:
242+
return dist.get_world_size(group)
243+
244+
245+
def get_nd_rank(group) -> int:
246+
if isinstance(group, tuple):
247+
return np.ravel_multi_index(
248+
tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
249+
)
250+
else:
251+
return dist.get_rank(group)

colossalai/zero/low_level/bookkeeping/base_store.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
from typing import Tuple, Union
2+
3+
import numpy as np
14
import torch.distributed as dist
25
from torch.distributed import ProcessGroup
36

47

58
class BaseStore:
    """Base bookkeeping store recording the topology of the (possibly nd) dp group."""

    def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
        """
        Args:
            torch_pg: a single process group, or a tuple of process groups that
                are the axes of an nd process-group mesh (e.g. extra dp x inner dp).
        """
        if isinstance(torch_pg, tuple):
            # nd mesh: flatten the per-axis sizes/ranks into one dp-group view.
            self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
            self._world_size = int(np.prod(self.sizes))
            # int(...) keeps the rank a plain Python int, consistent with
            # _world_size above — np.ravel_multi_index returns a numpy integer.
            self._local_rank = int(
                np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
            )
        else:
            self._world_size = dist.get_world_size(group=torch_pg)
            self._local_rank = dist.get_rank(group=torch_pg)
            self.sizes = [self._world_size]
        self.torch_pg = torch_pg
1019

1120
@property

colossalai/zero/low_level/bookkeeping/tensor_bucket.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import Optional
22

3+
import numpy as np
34
import torch
45
import torch.distributed as dist
56
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
67

78
from colossalai.quantization.fp8 import all_gather_fp8
9+
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
810

911

1012
class TensorBucket:
@@ -65,12 +67,18 @@ def unflatten_and_copy(self, flat_tensor):
6567

6668
def all_gather(self, group=None, fp8_communication: bool = False):
6769
flat = self.flatten()
68-
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
70+
if isinstance(group, tuple):
71+
world_size = np.prod([dist.get_world_size(pg) for pg in group])
72+
else:
73+
world_size = dist.get_world_size(group)
74+
buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
6975
if fp8_communication:
76+
# TODO: fit fp8
7077
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
7178
else:
72-
dist.all_gather_into_tensor(buffer, flat, group=group)
73-
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
79+
# dist.all_gather_into_tensor(buffer, flat, group=group)
80+
all_gather_into_flat_tensor_nd(buffer, flat, group=group)
81+
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
7482
# transpose the list of list
7583
unflat_buffers = list(map(list, zip(*unflat_buffers)))
7684
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

0 commit comments

Comments
 (0)