Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ def __init__(self,
except ImportError:
self.use_deep_gemm = False
logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
try:
from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep
self.use_deepep = use_deepep
self.deepep_buffer = DeepEPBuffer
self.deepep_mode = DeepEPMode
except ImportError:
self.use_deepep = False
logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')

# pre-allocate buffer
self.fusedmoe_build(True)
Expand Down Expand Up @@ -592,6 +600,14 @@ def _patched_fusedmoe_forward(*args, **kwargs):

return deepep_moe

def update_dispatch_mode(self):
    """Select the DeepEP dispatch mode for the current step.

    Decoding steps switch the DeepEP buffer to LOW_LATENCY dispatch;
    all other steps (e.g. prefill) use NORMAL. No-op when DeepEP is
    not installed (``self.use_deepep`` is falsy).
    """
    if not self.use_deepep:
        return
    step_ctx = get_step_ctx_manager().current_context()
    # LOW_LATENCY only pays off for the small per-step batches of decoding.
    mode = (self.deepep_mode.LOW_LATENCY
            if step_ctx.is_decoding else self.deepep_mode.NORMAL)
    self.deepep_buffer.set_deepep_mode(mode)


class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
"""Triton fused moe blocked f8 builder."""
Expand Down
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,11 @@ def get_input_embeddings(self):
"""Get input embeddings."""
return self.model.get_input_embeddings()

def _update_dispatch_mode(self):
    """Forward the current step's dispatch-mode update to the MoE backend.

    Inspects the first decoder layer's MLP; when it is a DeepseekV2MoE
    whose experts implementation exposes ``update_dispatch_mode``, call
    it so the DeepEP buffer mode matches the current step (prefill vs
    decode). Silently does nothing otherwise.
    """
    # All layers share one experts impl configuration, so probing layer 0
    # is sufficient — presumably; verify if heterogeneous layers appear.
    mlp = self.model.layers[0].mlp
    if not isinstance(mlp, DeepseekV2MoE):
        return
    impl = mlp.experts.impl
    if hasattr(impl, 'update_dispatch_mode'):
        impl.update_dispatch_mode()

def prepare_inputs_for_generation(
self,
past_key_values: List[List[torch.Tensor]],
Expand All @@ -1169,6 +1174,8 @@ def prepare_inputs_for_generation(
position_ids = context.position_ids
attn_metadata = context.attn_metadata

self._update_dispatch_mode()

return dict(
input_ids=input_ids,
position_ids=position_ids,
Expand Down