Skip to content

Commit 7863313

Browse files
authored
fix dllm long-context (#4012)
* fix dllm long-context
* broadcast extra inputs
* fix sdar warmup
1 parent 2a826e3 commit 7863313

File tree

5 files changed

+81
-36
lines changed

5 files changed

+81
-36
lines changed

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ async def _async_model_forward(
438438
):
439439
"""Model forward."""
440440
max_prefill_token_num = self.cache_config.max_prefill_token_num
441+
strategy = self.agent_strategy
441442

442443
class _OutputGather:
443444
"""Output gather."""
@@ -469,7 +470,11 @@ def gather(self, output):
469470
def get_output(self):
    """Return the gathered output.

    Relies on ``return_logits`` and ``strategy`` from the enclosing scope.

    * When logits are not requested, the accumulated buffer is treated as a
      single sequence and handed to ``strategy.slice_outputs`` so that each
      strategy decides which trailing positions to keep.
    * Otherwise the whole buffer is moved to ``self._device`` after a CUDA
      synchronize.
    """
    if not return_logits:
        # Number of rows in the buffer once every leading dim is flattened,
        # i.e. everything except the last (feature) dimension.
        num_tokens = self._output.numel() // self._output.size(-1)
        # A length must be an integer tensor.  Using the buffer's own dtype
        # (typically fp16/bf16) cannot represent large token counts exactly,
        # which matters precisely for long-context inputs.
        seqlen = torch.full((1, ), num_tokens, device=self._output.device, dtype=torch.long)
        return strategy.slice_outputs(self._output, seqlen)

    # NOTE(review): presumably ensures pending writes into the buffer have
    # finished before it is copied -- confirm against the gather() path.
    torch.cuda.synchronize()
    return self._output.to(self._device)
475480

@@ -562,17 +567,14 @@ def _push_output(self, output: BatchedOutputs):
562567
self._out_que.put_nowait((output, event))
563568

564569
@contextmanager
def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):
    """Context manager that optionally broadcasts the sampled tokens.

    The actual communication is delegated to
    ``self.agent_strategy.broadcast_next_token`` so that each strategy can
    broadcast whatever it needs in addition to the token ids.

    Args:
        next_token_ids: token ids to broadcast.
        extra_inputs: strategy specific extra inputs forwarded to the strategy.
        enable: when False nothing is broadcast and the body simply runs.

    Yields:
        Whatever the strategy's context manager yields, or ``None`` when
        broadcasting is disabled.
    """
    if enable:
        strategy_ctx = self.agent_strategy.broadcast_next_token(next_token_ids, extra_inputs, self.dist_ctx)
        with strategy_ctx as handle:
            yield handle
    else:
        yield
576578

577579
async def _async_step_background(
578580
self,
@@ -698,6 +700,7 @@ async def __prepare_dp():
698700
seq_length = output.get('seq_length', inputs.seq_length)
699701
last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob]
700702
extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length)
703+
model_metas = output.get('model_metas')
701704

702705
# output empty for dummy inputs
703706
if is_dummy:
@@ -711,47 +714,40 @@ async def __prepare_dp():
711714
# sampling
712715
next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs)
713716

714-
with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
715-
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
717+
# post sampling
718+
next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
719+
extra_inputs)
716720

717-
# post sampling
718-
next_token_ids, extra_inputs = self.agent_strategy.post_sampling(
719-
inputs, last_logits, next_token_ids, extra_inputs)
721+
with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
722+
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
720723

721724
# stopping criteria
722725
stopped, stop_pos, stopping_criteria = stopping_criteria.step(next_token_ids,
723726
sampling_inputs.stop_words,
724727
inputs=inputs,
725728
extra_inputs=extra_inputs)
729+
730+
# send output
731+
logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
732+
extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
733+
self._push_output(
734+
BatchedOutputs(next_token_ids=next_token_ids,
735+
logits=logits if return_logits else None,
736+
stopped=stopped,
737+
stop_pos=stop_pos,
738+
model_metas=model_metas,
739+
logprobs=logprobs,
740+
extra_outputs=extra_outputs))
726741
else:
727742
# Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,
728743
# as it can trigger recompilation on different ranks when using torch.compile.
729-
with torch.inference_mode():
730-
next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0))
731-
logprobs = None
744+
next_token_ids, extra_inputs = self.agent_strategy.make_dummy_next_token(
745+
inputs, last_logits, extra_inputs)
732746

733747
# broadcast next token for TP > 1
734-
with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
748+
with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
735749
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
736750

737-
# post sampling
738-
next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
739-
extra_inputs)
740-
741-
# send output
742-
model_metas = output.get('model_metas')
743-
if need_output:
744-
logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
745-
extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
746-
self._push_output(
747-
BatchedOutputs(next_token_ids=next_token_ids,
748-
logits=logits if return_logits else None,
749-
stopped=stopped,
750-
stop_pos=stop_pos,
751-
model_metas=model_metas,
752-
logprobs=logprobs,
753-
extra_outputs=extra_outputs))
754-
755751
# update for next loop
756752
if is_decoding and idx < loop_count - 1:
757753
inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs)

lmdeploy/pytorch/strategies/ar/model_agent.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from contextlib import contextmanager
23
from dataclasses import dataclass
34
from typing import Any, List, Optional
45

56
import torch
67
from torch.profiler import record_function
78

9+
import lmdeploy.pytorch.distributed as dist
10+
from lmdeploy.pytorch.distributed import DistContext
811
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
912
from lmdeploy.pytorch.messages import SchedulerSequence
1013
from lmdeploy.pytorch.model_inputs import ModelInputs
@@ -106,3 +109,11 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_
106109
extra_inputs: ARExtraInputs):
107110
"""Post sampling."""
108111
return next_token_ids, extra_inputs
112+
113+
@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: DistContext):
    """Broadcast next token ids from rank 0 over the TP GPU group.

    The broadcast is launched asynchronously on entry so the body of the
    ``with`` statement can overlap with the communication; the handle is
    waited on when the context exits.

    Args:
        next_token_ids: tensor broadcast in place from ``src=0``.
        extra_inputs: unused here; accepted for a uniform strategy interface.
        dist_ctx: distributed context providing ``tp_gpu_group``.
    """
    tp_gpu_group = dist_ctx.tp_gpu_group
    handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
    try:
        yield
    finally:
        # Wait even if the body raised, so the pending collective is never
        # left un-awaited.  No handle is returned when there is nothing to
        # wait for.
        if handle is not None:
            handle.wait()

lmdeploy/pytorch/strategies/base/model_agent.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from abc import ABC, abstractmethod
3+
from contextlib import contextmanager
34
from dataclasses import dataclass, fields
45
from typing import TYPE_CHECKING, Any, List, Optional
56

67
import numpy as np
78
import torch
89

910
if TYPE_CHECKING:
11+
from lmdeploy.pytorch.distributed import DistContext
1012
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
1113
from lmdeploy.pytorch.messages import SchedulerSequence
1214
from lmdeploy.pytorch.model_inputs import ModelInputs
@@ -33,6 +35,10 @@ def to_device(self, device: str, non_blocking: bool = False):
3335
"""To device."""
3436
return to_device(self, device, non_blocking)
3537

38+
def broadcast(self, src: int, group, async_op=False):
    """Broadcast extra inputs.

    The base implementation has nothing to broadcast and returns ``None``;
    subclasses carrying tensors override this.

    Args:
        src: source rank of the broadcast.
        group: process group to broadcast over.
        async_op: whether the broadcast should be asynchronous.
    """
    return None
41+
3642

3743
@dataclass
3844
class ExtraOutputs(ABC):
@@ -130,3 +136,14 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_
130136
extra_inputs: ExtraInputs):
131137
"""Post sampling."""
132138
pass
139+
140+
def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs):
    """Create placeholder next-token ids to be filled in by a broadcast.

    Args:
        inputs: model inputs; ``input_ids`` supplies dtype and device.
        logits: only its first dimension is used, as the number of ids.
        extra_inputs: returned unchanged.

    Returns:
        ``(next_token_ids, extra_inputs)`` where ``next_token_ids`` is a
        zero-filled 1-D tensor with ``logits.size(0)`` entries.
    """
    num_ids = logits.size(0)
    # Allocate under inference mode so the result is an inference tensor.
    with torch.inference_mode():
        placeholder_ids = inputs.input_ids.new_zeros(num_ids)
    return placeholder_ids, extra_inputs
145+
146+
@abstractmethod
@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: 'DistContext'):
    """Broadcast next token ids and extra inputs.

    Abstract context manager; concrete strategies must override it.

    Args:
        next_token_ids: token ids to broadcast.
        extra_inputs: strategy specific extra inputs to broadcast alongside.
        dist_ctx: distributed context describing the communication groups.
    """

lmdeploy/pytorch/strategies/base/model_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def make_dummy_inputs(batch_size: int,
3636
num_ignored_history=num_ignored_history,
3737
max_q_seqlen=max_q_seqlen,
3838
max_kv_seqlen=max_kv_seqlen,
39-
sum_kv_seqlen=batch_size,
39+
sum_kv_seqlen=num_tokens,
4040
local_adapter_ids=local_adapter_ids,
4141
)
4242

lmdeploy/pytorch/strategies/dllm/model_agent.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from contextlib import contextmanager
23
from dataclasses import dataclass
34
from typing import Any, List, Optional
45

56
import numpy as np
67
import torch
78
from torch.profiler import record_function
89

10+
import lmdeploy.pytorch.distributed as dist
911
from lmdeploy.pytorch import consts
1012
from lmdeploy.pytorch.config import DLLMConfig
13+
from lmdeploy.pytorch.distributed import DistContext
1114
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
1215
from lmdeploy.pytorch.messages import SchedulerSequence
1316
from lmdeploy.pytorch.model_inputs import ModelInputs
@@ -23,6 +26,9 @@ class DLLMExtraInputs(ExtraInputs):
2326
"""DLLM extra inputs."""
2427
dllm_mask: torch.Tensor
2528

29+
def broadcast(self, src: int, group, async_op=False):
    """Broadcast ``dllm_mask`` from rank ``src`` over ``group``.

    Returns whatever ``dist.broadcast`` returns for the given ``async_op``.
    """
    mask = self.dllm_mask
    return dist.broadcast(mask, src=src, group=group, async_op=async_op)
31+
2632

2733
@dataclass
2834
class DLLMExtraOutputs(ExtraOutputs):
@@ -216,3 +222,18 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_
216222

217223
extra_inputs.dllm_mask = dllm_mask
218224
return next_token_ids, extra_inputs
225+
226+
def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: DLLMExtraInputs):
    """Build zero-filled next-token ids used as a broadcast destination.

    Args:
        inputs: model inputs; the ids inherit dtype/device of ``input_ids``.
        logits: its leading dimension gives the number of ids to allocate.
        extra_inputs: passed through untouched.

    Returns:
        Tuple of the placeholder ids and the unchanged ``extra_inputs``.
    """
    count = logits.size(0)
    # Created inside inference mode so the tensor is an inference tensor.
    with torch.inference_mode():
        dummy_ids = inputs.input_ids.new_zeros(count)
    return dummy_ids, extra_inputs
231+
232+
@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: DLLMExtraInputs, dist_ctx: DistContext):
    """Broadcast next token ids and extra inputs from rank 0.

    Both broadcasts are launched asynchronously over the TP GPU group on
    entry, so the ``with`` body can overlap with communication.  Every
    returned handle is waited on when the context exits.

    Args:
        next_token_ids: tensor broadcast in place from ``src=0``.
        extra_inputs: extra inputs whose ``broadcast`` method is invoked.
        dist_ctx: distributed context providing ``tp_gpu_group``.
    """
    tp_gpu_group = dist_ctx.tp_gpu_group
    # Keep BOTH handles: previously the token-id handle was dropped and only
    # the extra-inputs broadcast was awaited explicitly.
    handles = [
        dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True),
        extra_inputs.broadcast(src=0, group=tp_gpu_group, async_op=True),
    ]
    try:
        yield
    finally:
        # Wait even when the body raised; skip entries without a handle
        # (e.g. an ``ExtraInputs.broadcast`` that has nothing to send).
        for handle in handles:
            if handle is not None:
                handle.wait()

0 commit comments

Comments
 (0)