diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py
index d0e9e663f2..2a7e23b509 100644
--- a/lmdeploy/pytorch/backends/cuda/moe.py
+++ b/lmdeploy/pytorch/backends/cuda/moe.py
@@ -474,6 +474,14 @@ def __init__(self,
         except ImportError:
             self.use_deep_gemm = False
             logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
+        try:
+            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep
+            self.use_deepep = use_deepep
+            self.deepep_buffer = DeepEPBuffer
+            self.deepep_mode = DeepEPMode
+        except ImportError:
+            self.use_deepep = False
+            logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')
 
         # pre-allocate buffer
         self.fusedmoe_build(True)
@@ -592,6 +600,14 @@ def _patched_fusedmoe_forward(*args, **kwargs):
 
         return deepep_moe
 
+    def update_dispatch_mode(self):
+        if self.use_deepep:
+            deepep_mode_type = self.deepep_mode.NORMAL
+            step_ctx = get_step_ctx_manager().current_context()
+            if step_ctx.is_decoding:
+                deepep_mode_type = self.deepep_mode.LOW_LATENCY
+            self.deepep_buffer.set_deepep_mode(deepep_mode_type)
+
 
 class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
     """Triton fused moe blocked f8 builder."""
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index f0715f993f..67ca3c0324 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -1158,6 +1158,11 @@ def get_input_embeddings(self):
         """Get input embeddings."""
        return self.model.get_input_embeddings()
 
+    def _update_dispatch_mode(self):
+        if isinstance(self.model.layers[0].mlp, DeepseekV2MoE):
+            if hasattr(self.model.layers[0].mlp.experts.impl, 'update_dispatch_mode'):
+                self.model.layers[0].mlp.experts.impl.update_dispatch_mode()
+
     def prepare_inputs_for_generation(
         self,
         past_key_values: List[List[torch.Tensor]],
@@ -1169,6 +1174,8 @@ def prepare_inputs_for_generation(
         position_ids = context.position_ids
         attn_metadata = context.attn_metadata
 
+        self._update_dispatch_mode()
+
         return dict(
             input_ids=input_ids,
             position_ids=position_ids,