
Commit bccd2bc

Let allgather and alltoall execute in parallel when both attention and MoE use TP
Signed-off-by: taozhiwei <[email protected]>
1 parent: b00b75f · commit: bccd2bc

File tree

2 files changed: +154, -24 lines
  deepspeed/moe/sharded_moe.py
  tests/unit/moe/test_pipeline.py

deepspeed/moe/sharded_moe.py

Lines changed: 65 additions & 24 deletions
@@ -26,7 +26,6 @@
 import torch.nn.functional as F
 from deepspeed.utils import groups
 from .mappings import drop_tokens, gather_tokens
-
 if TYPE_CHECKING:
     Base = Module[Tensor]
 else:
@@ -96,16 +95,19 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 class _AllToAll(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, async_op=False) -> Tensor:  # type: ignore
         ctx.group = group
         input = input.contiguous()
         output = torch.empty_like(input)
-        dist.all_to_all_single(output, input, group=group)
-        return output
+        work = dist.all_to_all_single(output, input, group=group, async_op=async_op)
+        if async_op:
+            return output, work
+        else:
+            return output

     @staticmethod
     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
+        return (None, _AllToAll.apply(ctx.group, *grad_output), None)


 # einsum rewrites are on par or more performant
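
The async_op path above returns the communication handle so the caller can overlap other work before waiting. A minimal sketch of how such a handle is used with torch.distributed (assumes the process group is already initialized; the helper name is ours, not part of the commit):

    import torch
    import torch.distributed as dist

    def async_all_to_all(input: torch.Tensor, group=None):
        # Mirrors the new forward() path: launch all_to_all_single without blocking.
        input = input.contiguous()
        output = torch.empty_like(input)
        work = dist.all_to_all_single(output, input, group=group, async_op=True)
        return output, work

    # output, work = async_all_to_all(chunk, group=ep_group)
    # ... launch the next chunk's all-to-all or other communication here ...
    # work.wait()  # `output` is only safe to read after wait()
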
@@ -550,6 +552,7 @@ class MOELayer(Base):
         expert (torch.nn.Module):
             expert network
     """
+    d2d_stream = torch.cuda.Stream()

     def __init__(self,
                  gate: Module,
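
d2d_stream above is a class-level side stream used for the device-to-device copies that scatter gathered shards into the output buffer. A minimal sketch of the wait_stream / torch.cuda.stream pattern it relies on (assumes a CUDA device; names are illustrative):

    import torch

    copy_stream = torch.cuda.Stream()          # plays the role of MOELayer.d2d_stream
    main_stream = torch.cuda.current_stream()

    src = torch.randn(1024, 1024, device="cuda")
    dst = torch.empty_like(src)

    copy_stream.wait_stream(main_stream)       # copies must not start before src is produced on main
    with torch.cuda.stream(copy_stream):
        dst.copy_(src)                         # runs on the side stream, overlapping work on main
    main_stream.wait_stream(copy_stream)       # main must not read dst before the copy finishes
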
@@ -572,6 +575,8 @@ def __init__(self,
         self.wall_clock_breakdown = False

         self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.enable_pipeline = True
+        self.shard_num = 4

         if self.use_tutel:
             logger.info('Using Tutel optimizations.')
@@ -586,8 +591,54 @@ def _set_ep_group(self, ep_group):
         self.ep_group = ep_group
         self.gate._set_ep_group(ep_group)

-    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+    # During multi-machine MoE training, alltoall is the communication between machines and
+    # allgather is the communication within a machine. They use different communication links,
+    # so they can be executed in parallel.
+    # The input has shape (E, C, M). Shard the input along the C dim and run alltoall shard by
+    # shard, so that the allgather of one shard overlaps with the alltoall of the next shard:
+    # A  E  I  M
+    # A1 E1 I1 M1
+    # A2 E2 I2 M2
+    # A3 E3 I3 M3
+    # A4 E4 I4 M4
+    def pipeline_alltoall_with_allgather(self, input, shard_dim=1) -> Tensor:
+        if not self.enable_pipeline:
+            input = _AllToAll.apply(self.ep_group, input)
+            input = gather_tokens(input, dim=shard_dim)
+            return input
+
+        assert self.shard_num > 0, f"shard_num must be a positive number, but got {self.shard_num}"
+        input_chunks = list(input.chunk(self.shard_num, dim=shard_dim))
+        world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
+        dims = list(input.size())
+        dims[shard_dim] = dims[shard_dim] * world_size
+        output = torch.empty(dims, device=input.device)
+        input_gather_dim_len = input.shape[shard_dim]
+        have_gather_len = 0
+        works = []
+        for i in range(len(input_chunks)):
+            input_chunks[i], work = _AllToAll.apply(self.ep_group, input_chunks[i], True)
+            works.append(work)
+
+        current_stream = torch.cuda.current_stream()
+        for i in range(len(input_chunks)):
+            works[i].wait()
+            # allgather and chunk on dim 0, so we can avoid an unnecessary cat in gather_tokens
+            gather_out = gather_tokens(input_chunks[i], dim=0)
+            gather_list = gather_out.chunk(world_size, dim=0)
+            dim_len = gather_list[0].shape[shard_dim]
+            MOELayer.d2d_stream.wait_stream(current_stream)
+
+            for j in range(len(gather_list)):
+                start = input_gather_dim_len * j + have_gather_len
+                with torch.cuda.stream(MOELayer.d2d_stream):
+                    torch.narrow(output, shard_dim, start, dim_len).copy_(gather_list[j])
+            have_gather_len += dim_len
+
+        current_stream.wait_stream(MOELayer.d2d_stream)
+        return output

+    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(MOE_TIMER).start()
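
To make the index arithmetic above concrete: each tensor-parallel rank j contributes a full-C slice of the output, and have_gather_len tracks how much of that slice has been filled so far. A tiny worked example with hypothetical sizes (not taken from the commit):

    # Hypothetical sizes: C = 8 tokens, world_size = 2 TP ranks, shard_num = 4.
    C, world_size, shard_num = 8, 2, 4
    dim_len = C // shard_num                     # 2 tokens per shard (even split here)
    have_gather_len = 0
    for i in range(shard_num):                   # shard index
        for j in range(world_size):              # TP rank index
            start = C * j + have_gather_len      # same formula as in the method above
            print(f"shard {i} from rank {j} -> output[{start}:{start + dim_len}] along shard_dim")
        have_gather_len += dim_len
    # Rank 0's shards land in output[0:8] and rank 1's in output[8:16], which matches
    # gather_tokens(_AllToAll.apply(ep_group, input), dim=1) computed on the whole tensor.
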

@@ -611,9 +662,6 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
             dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)

-        if self.wall_clock_breakdown:
-            self.timers(FIRST_ALLTOALL_TIMER).start()
-
         tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
         if tensor_model_world_size > 1:
             # If the non-expert is tensor-parallel,
@@ -628,18 +676,17 @@
             # an allgather to ensure correctness,
             dispatched_input = drop_tokens(dispatched_input, dim=1)

-        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
-
         if self.wall_clock_breakdown:
-            self.timers(FIRST_ALLTOALL_TIMER).stop()
-            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
+            self.timers(FIRST_ALLTOALL_TIMER).start()

         if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
-            # if both expert and non-expert are tensor-parallel
-            # the dropped duplicate tokens need to be gathered on each
-            # tensor parallel rank again to ensure correctness
-            dispatched_input = gather_tokens(dispatched_input, dim=1)
+            dispatched_input = self.pipeline_alltoall_with_allgather(dispatched_input)
+        else:
+            dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

+        if self.wall_clock_breakdown:
+            self.timers(FIRST_ALLTOALL_TIMER).stop()
+            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
         expert_output = self.experts(dispatched_input)
@@ -654,18 +701,12 @@
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()

-        expert_output = _AllToAll.apply(self.ep_group, expert_output)
+        expert_output = self.pipeline_alltoall_with_allgather(expert_output)

         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
             self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

-        if tensor_model_world_size > 1:
-            # the dropped duplicate tokens need to be gathered on each
-            # tensor parallel rank again for the tensor-parallel
-            # non-expert of the next layer.
-            expert_output = gather_tokens(expert_output, dim=1)
-
         if self.use_tutel:
             combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
         else:

tests/unit/moe/test_pipeline.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+import torch
+import deepspeed
+import pytest
+from unit.common import DistributedTest
+from deepspeed import get_accelerator
+from deepspeed.moe.sharded_moe import _AllToAll
+from deepspeed.moe.mappings import gather_tokens
+from deepspeed.moe.layer import MoE
+
+
+class MPU():
+
+    def __init__(self, tp_world_size):
+        self.rank = deepspeed.comm.get_rank()
+        self.world_size = deepspeed.comm.get_world_size()
+        self.tp_world_size = tp_world_size
+
+        for i in range(0, self.world_size, tp_world_size):
+            ranks = range(i, i + tp_world_size)
+            group = deepspeed.comm.new_group(ranks)
+            if self.rank in ranks:
+                self.tp_group = group
+
+        for i in range(0, tp_world_size):
+            ranks = range(i, self.world_size, tp_world_size)
+            group = deepspeed.comm.new_group(ranks)
+            if self.rank in ranks:
+                self.dp_group = group
+
+    def get_model_parallel_rank(self):
+        return self.rank % self.tp_world_size
+
+    def get_model_parallel_world_size(self):
+        return self.tp_world_size
+
+    def get_data_parallel_rank(self):
+        return self.rank // self.tp_world_size
+
+    def get_data_parallel_world_size(self):
+        return self.world_size // self.tp_world_size
+
+    def get_data_parallel_group(self):
+        return self.dp_group
+
+    def get_model_parallel_group(self):
+        return self.tp_group
+
+
+@pytest.mark.parametrize("shard_num", [6, 10])
+@pytest.mark.parametrize("C, M, scale", [(92, 32, 1), (209, 128, 5)])
+class TestPipelineCommunication(DistributedTest):
+    world_size = 8
+
+    def test(self, shard_num, C, M, scale):
+        tp_size = 2
+        world_size = deepspeed.comm.get_world_size()
+        E = world_size
+        ep_size = 4
+        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
+        hidden_dim = M
+        device = get_accelerator().current_device_name()
+        tensor_parallel_expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, 4 * hidden_dim // tp_size),
+                                                     torch.nn.ReLU(),
+                                                     torch.nn.Linear(4 * hidden_dim // tp_size, hidden_dim))
+
+        model = MoE(
+            hidden_size=hidden_dim,
+            expert=tensor_parallel_expert,
+            num_experts=world_size * scale,
+            ep_size=ep_size,
+            use_residual=True,
+            enable_expert_tensor_parallelism=True,
+        )
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                              model=model,
+                                              optimizer=optimizer,
+                                              dist_init_required=False,
+                                              mpu=MPU(tp_size))
+        model.deepspeed_moe.shard_num = shard_num
+        input = torch.rand(E, C, M, device=device)
+
+        # pipelined alltoall overlapped with allgather
+        pipeline_output = model.deepspeed_moe.pipeline_alltoall_with_allgather(input)
+
+        # reference: alltoall first, then allgather
+        alltoall_output = _AllToAll.apply(model.deepspeed_moe.ep_group, input)
+        gather_output = gather_tokens(alltoall_output, dim=1)
+        assert torch.allclose(pipeline_output, gather_output, atol=1e-07), f"pipeline_output {pipeline_output} is not equal to gather_output {gather_output}"
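
For reference, the two knobs introduced in sharded_moe.py can be tuned the same way the test above reaches the layer, via model.deepspeed_moe (a hedged sketch; the trade-off comment is our assumption, not something measured in this commit):

    moe_layer = model.deepspeed_moe        # MOELayer instance inside deepspeed.moe.layer.MoE
    moe_layer.shard_num = 8                # more shards -> finer-grained overlap, more kernel launches
    moe_layer.enable_pipeline = False      # fall back to plain alltoall followed by gather_tokens
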
