File tree Expand file tree Collapse file tree 2 files changed +21
-5
lines changed Expand file tree Collapse file tree 2 files changed +21
-5
lines changed Original file line number Diff line number Diff line change @@ -474,6 +474,13 @@ def __init__(self,
474474 except ImportError :
475475 self .use_deep_gemm = False
476476 logger .warning ('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM' )
477+ try :
478+ import deep_ep
479+ from dlblas .layers .moe .token_dispatcher import DeepEPBuffer , DeepEPMode
480+ self .use_deepep = True
481+ except ImportError :
482+ self .use_deepep = False
483+ logger .warning ('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP' )
477484
478485 # pre-allocate buffer
479486 self .fusedmoe_build (True )
@@ -592,6 +599,14 @@ def _patched_fusedmoe_forward(*args, **kwargs):
592599
593600 return deepep_moe
594601
602+ def update_dispatch_mode (self ):
603+ if self .use_deepep :
604+ deepep_mode = DeepEPMode .NORMAL
605+ step_ctx = get_step_ctx_manager ().current_context ()
606+ if step_ctx .is_decoding :
607+ deepep_mode = DeepEPMode .LOW_LATENCY
608+ DeepEPBuffer .set_deepep_mode (deepep_mode )
609+
595610
596611class TritonFusedMoEBlockedF8Builder (FusedMoEBlockedF8Builder ):
597612 """Triton fused moe blocked f8 builder."""
Original file line number Diff line number Diff line change @@ -1158,6 +1158,11 @@ def get_input_embeddings(self):
11581158 """Get input embeddings."""
11591159 return self .model .get_input_embeddings ()
11601160
1161+ def _update_dispatch_mode (self ):
1162+ if isinstance (self .model .layers [0 ].mlp , DeepseekV2MoE ):
1163+ if hasattr (self .model .layers [0 ].mlp .experts .impl , "update_dispatch_mode" ):
1164+ self .model .layers [0 ].mlp .experts .impl .update_dispatch_mode ()
1165+
11611166 def prepare_inputs_for_generation (
11621167 self ,
11631168 past_key_values : List [List [torch .Tensor ]],
@@ -1169,11 +1174,7 @@ def prepare_inputs_for_generation(
11691174 position_ids = context .position_ids
11701175 attn_metadata = context .attn_metadata
11711176
1172- from dlblas .layers .moe .token_dispatcher import DeepEPBuffer , DeepEPMode
1173- deepep_mode = DeepEPMode .NORMAL
1174- if context .is_decoding :
1175- deepep_mode = DeepEPMode .LOW_LATENCY
1176- DeepEPBuffer .set_deepep_mode (deepep_mode )
1177+ self ._update_dispatch_mode ()
11771178
11781179 return dict (
11791180 input_ids = input_ids ,
You can’t perform that action at this time.
0 commit comments