Skip to content

Commit c8ce820

Browse files
committed
update_dlblasanddeepep
1 parent febfa9e commit c8ce820

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

lmdeploy/pytorch/backends/cuda/moe.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,13 @@ def __init__(self,
474474
except ImportError:
475475
self.use_deep_gemm = False
476476
logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
477+
try:
478+
import deep_ep
479+
from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode
480+
self.use_deepep = True
481+
except ImportError:
482+
self.use_deepep = False
483+
logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')
477484

478485
# pre-allocate buffer
479486
self.fusedmoe_build(True)
@@ -592,6 +599,14 @@ def _patched_fusedmoe_forward(*args, **kwargs):
592599

593600
return deepep_moe
594601

602+
def update_dispatch_mode(self):
    """Select the DeepEP dispatch mode for the current step.

    Does nothing when ``self.use_deepep`` is false. Otherwise the mode is
    ``DeepEPMode.LOW_LATENCY`` when the current step context reports
    ``is_decoding`` and ``DeepEPMode.NORMAL`` when it does not; the chosen
    mode is handed to ``DeepEPBuffer.set_deepep_mode``.
    """
    if not self.use_deepep:
        return

    # The ``try`` block in ``__init__`` imports these names into that
    # function's local scope only, so they are not visible here. Import
    # them again at function scope; ``use_deepep`` is only set to True
    # after that same import succeeded, so this is not expected to fail.
    from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode

    step_ctx = get_step_ctx_manager().current_context()
    deepep_mode = DeepEPMode.LOW_LATENCY if step_ctx.is_decoding else DeepEPMode.NORMAL
    DeepEPBuffer.set_deepep_mode(deepep_mode)
609+
595610

596611
class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
597612
"""Triton fused moe blocked f8 builder."""

lmdeploy/pytorch/models/deepseek_v2.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,11 @@ def get_input_embeddings(self):
11581158
"""Get input embeddings."""
11591159
return self.model.get_input_embeddings()
11601160

1161+
def _update_dispatch_mode(self):
    """Ask the MoE expert implementation to refresh its dispatch mode.

    Scans ``self.model.layers`` for the first layer whose ``mlp`` is a
    ``DeepseekV2MoE`` and whose ``experts.impl`` exposes an
    ``update_dispatch_mode`` method, calls that method once and stops.
    Does nothing when no such layer exists (including an empty layer list).
    """
    # Layer 0 is not guaranteed to be an MoE layer, so looking only at
    # ``layers[0]`` can silently skip the update.
    # NOTE(review): DeepSeek configs appear to keep the leading layers dense
    # (``first_k_dense_replace``) -- confirm against the decoder layer setup.
    for layer in self.model.layers:
        mlp = layer.mlp
        if not isinstance(mlp, DeepseekV2MoE):
            continue
        update = getattr(mlp.experts.impl, 'update_dispatch_mode', None)
        if update is not None:
            # One call is enough: the original code also invoked the hook
            # on a single layer only.
            update()
            return
1165+
11611166
def prepare_inputs_for_generation(
11621167
self,
11631168
past_key_values: List[List[torch.Tensor]],
@@ -1169,11 +1174,7 @@ def prepare_inputs_for_generation(
11691174
position_ids = context.position_ids
11701175
attn_metadata = context.attn_metadata
11711176

1172-
from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode
1173-
deepep_mode = DeepEPMode.NORMAL
1174-
if context.is_decoding:
1175-
deepep_mode = DeepEPMode.LOW_LATENCY
1176-
DeepEPBuffer.set_deepep_mode(deepep_mode)
1177+
self._update_dispatch_mode()
11771178

11781179
return dict(
11791180
input_ids=input_ids,

0 commit comments

Comments
 (0)