Commit f63730d

Refactor dp tp (#4004)
* WIP
* moe
* refactor
* fix
* fix
* vis
* fix pd
* optimize gather
* expose layer tp
* ep + attn tp allgather
* fix not aligned weight
* moe microbatch pipeline
* fix
* fix
* patch deep-gemm
* fix dummy
* fix reduce_scatter
* linear reduce scatter
* avoid oom
* fix linear
* fix long context
1 parent 02cd79b commit f63730d

47 files changed: +1521 / -838 lines

lmdeploy/messages.py

Lines changed: 6 additions & 0 deletions
@@ -292,6 +292,9 @@ class PytorchEngineConfig:
         session_len (int): Max session length. Default None.
         max_batch_size (int): Max batch size. If it is not specified,
             the engine will automatically set it according to the device
+        attn_tp_size (int): tp size for attention, only works for dp>1
+        mlp_tp_size (int): tp size for mlp, only works for dp>1
+        moe_tp_size (int): tp size for moe, only works for dp>1
         cache_max_entry_count (float): the percentage of gpu memory occupied
             by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
             it defaults to 0.8, signifying the percentage of FREE GPU memory
@@ -353,6 +356,9 @@ class PytorchEngineConfig:
     ep: int = 1
     session_len: int = None
     max_batch_size: int = None
+    attn_tp_size: int = None
+    mlp_tp_size: int = None
+    moe_tp_size: int = None
     cache_max_entry_count: float = 0.8
     prefill_interval: int = 16
     block_size: int = 64
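
These three fields extend the engine's existing parallelism options (`tp`, `dp`, `ep`) with per-module TP sizes that only apply when dp > 1. A minimal configuration sketch; the model id and parallel sizes below are illustrative, not taken from this commit:

from lmdeploy import PytorchEngineConfig, pipeline

# Illustrative values only. The per-module sizes take effect when dp > 1;
# all three default to None, in which case the engine derives them itself.
backend_config = PytorchEngineConfig(
    dp=2,               # data-parallel size (the docstring's "only works for dp>1")
    tp=4,               # overall tensor-parallel size
    attn_tp_size=2,     # tp size for attention
    mlp_tp_size=4,      # tp size for dense MLP layers
    moe_tp_size=4,      # tp size for MoE experts
)
pipe = pipeline('deepseek-ai/DeepSeek-V3', backend_config=backend_config)  # hypothetical model id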

lmdeploy/pytorch/backends/awq_modules.py

Lines changed: 6 additions & 1 deletion
@@ -17,7 +17,12 @@ def update_weights(self,
         return qweight, scales, qzeros, bias

     @abstractmethod
-    def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
+    def forward(self,
+                x,
+                weight: torch.Tensor,
+                bias: Optional[torch.Tensor] = None,
+                all_reduce: bool = False,
+                group: Optional[torch.distributed.ProcessGroup] = None):
         """forward."""
         raise NotImplementedError

lmdeploy/pytorch/backends/blockedf8_modules.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 from typing import List, Optional

 import torch
+import torch.distributed as dist


 class LinearBlockedF8Impl(ABC):
@@ -19,6 +20,7 @@ def forward(self,
                 scale: torch.Tensor,
                 bias: Optional[torch.Tensor] = None,
                 all_reduce: bool = False,
+                group: Optional[dist.ProcessGroup] = None,
                 rank: int = 0,
                 scatter_size: List[int] = None):
         """forward."""

lmdeploy/pytorch/backends/cuda/attention.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def __init__(
         self.flash_attention_fwd = flash_attention_fwd

         # for alibi attention
-        world_size, rank = get_tp_world_rank()
+        world_size, rank = get_tp_world_rank('attn')
         self.alibi_head_offset = self.num_heads * rank
         self.alibi_num_heads = self.num_heads * world_size
         self.block_sparse_size = block_sparse_size
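
Looking up the world size and rank from the named 'attn' group (instead of the global TP group) keeps the ALiBi head partition correct when attention runs at its own TP size. A small sketch of the arithmetic, with `attn_tp_size` / `attn_rank` standing in for the values returned by `get_tp_world_rank('attn')`:

# Sketch: how ALiBi heads are partitioned across the attention TP group.
# num_local_heads corresponds to self.num_heads in the diff above.
def alibi_head_partition(num_local_heads: int, attn_tp_size: int, attn_rank: int):
    head_offset = num_local_heads * attn_rank      # first global head owned by this rank
    total_heads = num_local_heads * attn_tp_size   # global head count used for ALiBi slopes
    return head_offset, total_heads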

lmdeploy/pytorch/backends/cuda/awq_modules.py

Lines changed: 3 additions & 2 deletions
@@ -55,12 +55,13 @@ def forward(self,
                 scales: torch.Tensor,
                 qzeros: torch.Tensor,
                 bias: Optional[torch.Tensor] = None,
-                all_reduce: bool = False):
+                all_reduce: bool = False,
+                group: Optional[torch.distributed.ProcessGroup] = None):
         """forward."""
         out_features = scales.size(1)
         out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)
         if all_reduce:
-            dist.all_reduce(out)
+            dist.all_reduce(out, group=group)
         return out
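
The only functional change here is that the reduction is scoped to the caller-supplied ProcessGroup rather than the default (world) group, which matters once attention, MLP and MoE layers each have their own TP group. A self-contained torch.distributed sketch of the pattern; the sub-group construction at the end is hypothetical:

import torch
import torch.distributed as dist

def reduce_partial_output(out: torch.Tensor, group: dist.ProcessGroup = None) -> torch.Tensor:
    # Sum partial GEMM outputs across the ranks of one layer's TP group.
    # With group=None, torch.distributed falls back to the default group,
    # which is why the group must be threaded through explicitly here.
    if dist.is_initialized():
        dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
    return out

# Hypothetical usage: build a sub-group for the ranks sharing this layer's weights.
# mlp_group = dist.new_group(ranks=[0, 1, 2, 3])
# out = reduce_partial_output(out, group=mlp_group)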

lmdeploy/pytorch/backends/cuda/blockedf8_modules.py

Lines changed: 6 additions & 14 deletions
@@ -13,15 +13,6 @@
 logger = get_logger('lmdeploy')


-def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
-    """Reduce scatter."""
-    outs = out.split(tp_sizes, -2)
-    out = outs[rank]
-    outs = list(outs)
-    dist.reduce_scatter(out, outs)
-    return out
-
-
 class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
     """Triton linear blocked f8 implementation."""

@@ -37,6 +28,7 @@ def forward(self,
                 scale: torch.Tensor,
                 bias: Optional[torch.Tensor] = None,
                 all_reduce: bool = False,
+                group: Optional[dist.ProcessGroup] = None,
                 rank: int = 0,
                 scatter_size: List[int] = None):
         """forward."""
@@ -52,7 +44,7 @@ def forward(self,

         if all_reduce:
             if scatter_size is not None:
-                out = _reduce_scatter_input(out, rank, scatter_size)
+                out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
             else:
                 dist.all_reduce(out)
         return out
@@ -117,6 +109,7 @@ def forward(self,
                 scale: torch.Tensor,
                 bias: Optional[torch.Tensor] = None,
                 all_reduce: bool = False,
+                group: Optional[dist.ProcessGroup] = None,
                 rank: int = 0,
                 scatter_size: List[int] = None):
         """forward."""
@@ -128,12 +121,11 @@ def forward(self,
         out = out[:x.size(0)]
         if bias is not None:
             out += bias
+        out = out.unflatten(0, x_shape[:-1])

         if all_reduce:
             if scatter_size is not None:
-                out = _reduce_scatter_input(out, rank, scatter_size)
+                out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
             else:
-                dist.all_reduce(out)
-
-        out = out.unflatten(0, x_shape[:-1])
+                dist.all_reduce(out, group=group)
         return out
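
`reduce_scatter_by_tp_sizes` is not a stock torch.distributed call, so the `dist` used here is presumably lmdeploy's own distributed wrapper; the deleted `_reduce_scatter_input` shows the underlying idea. A sketch of such a helper, reconstructed from the removed code plus the new `group` argument:

from typing import List, Optional

import torch
import torch.distributed as dist

def reduce_scatter_by_tp_sizes(out: torch.Tensor,
                               rank: int,
                               tp_sizes: List[int],
                               group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # Reduce-scatter with unequal per-rank shard sizes: tp_sizes[i] is the number
    # of rows (along dim -2) that rank i keeps after the reduction.
    shards = list(out.split(tp_sizes, dim=-2))
    local = shards[rank]                       # this rank's output buffer
    dist.reduce_scatter(local, shards, group=group)
    return local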

lmdeploy/pytorch/backends/cuda/graph_runner.py

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ def update_inputs(self, inputs):
         meta = self.get_meta()
         padding_batch_size = meta.padding_batch_size
         tp_size = self._get_capture_tokens(padding_batch_size)
-        dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
+        dp_meta.sync_tp_size(tp_size)
         return inputs

     def get_capture_batch_sizes(self) -> List[int]:
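
`sync_tp_size` replaces the in-place rewrite of `dp_meta.tp_sizes`, moving that bookkeeping into the meta object itself. A hedged guess at the shape of such a method; the real `DPMeta` implementation is not part of this excerpt:

from dataclasses import dataclass, field
from typing import List

@dataclass
class DPMeta:
    # Sketch of a data-parallel metadata holder; only tp_sizes / sync_tp_size
    # are grounded in the diff above, everything else is an assumption.
    tp_sizes: List[int] = field(default_factory=list)

    def sync_tp_size(self, tp_size: int) -> None:
        # Equivalent to the removed line in update_inputs():
        # pad every rank's token count to the captured size.
        self.tp_sizes = [tp_size] * len(self.tp_sizes)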
