diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index e40839b1d475..a8063f99a6e5 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -26,7 +26,6 @@ import typing
 from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
-
 import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
@@ -63,6 +62,7 @@ from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.logger import init_logger
+from vllm.forward_context import get_forward_context
 logger = init_logger(__name__)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
@@ -88,10 +88,64 @@ if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS:
     from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16
     from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant
+
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
     rocm_aiter_fp8_quant_group_size = 128
+
+    def streams_breaks(
+        hidden_states: torch.Tensor,
+        shared_output_q: torch.Tensor,
+        shared_output_s: torch.Tensor,
+        router_logits: torch.Tensor,
+        layer_prefix: str,
+        routed_scaling_factor: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Get the correct DeepseekV2MoE instance for this layer.
+        ctx = get_forward_context()
+        self = ctx.no_compile_layers[layer_prefix]
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+
+        # Overlap the shared-expert down_proj on the alt stream with the routed experts.
+        with torch.cuda.stream(self.alt_stream):
+            shared_output, _ = self.shared_experts.down_proj(
+                shared_output_q, x_quant_scales=shared_output_s
+            )
+
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+
+        current_stream.wait_stream(self.alt_stream)
+
+        final_hidden_states = fused_mul_add(
+            final_hidden_states,
+            routed_scaling_factor, shared_output
+        )
+
+        return shared_output, final_hidden_states
+
+    # Fake impl, only shapes are needed
+    def streams_breaks_fake(
+        hidden_states: torch.Tensor,
+        shared_output_q: torch.Tensor,
+        shared_output_s: torch.Tensor,
+        router_logits: torch.Tensor,
+        layer_prefix: str,
+        routed_scaling_factor: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M, H = hidden_states.shape
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        shared_out = torch.empty((M, H), device=device, dtype=dtype)
+        final_out = torch.empty((M, H), device=device, dtype=dtype)
+
+        return shared_out, final_out
+
     def rocm_aiter_triton_fused_shared_expert_impl(
         hidden_states_shared: torch.Tensor,
         hidden_states_shared_scale: torch.Tensor,
@@ -133,6 +187,15 @@ def rocm_aiter_triton_fused_shared_expert_fake(
         router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device)
         return shared_output_q, shared_output_s, router_logits
 
+    # Register as torch.ops.vllm.streams_breaks
+    direct_register_custom_op(
+        op_name="streams_breaks",
+        op_func=streams_breaks,
+        mutates_args=[],
+        fake_impl=streams_breaks_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_triton_fused_shared_expert",
         op_func=rocm_aiter_triton_fused_shared_expert_impl,
@@ -200,6 +263,7 @@ def forward(self, x):
         return x
 
 
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
@@ -208,16 +272,18 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_eplb: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
-
+        self.prefix = prefix
         self.ep_group = get_ep_group().device_group
         self.ep_rank = self.ep_group.rank()
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
+        self.alt_stream = alt_stream
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -292,11 +358,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+
+        did_fma = False
         if (
             self.use_triton_fused_shared_expert
             and self.n_shared_experts is not None
         ):
             hidden_states_shared, hidden_states_shared_scale = hidden_states_shared
+
             shared_output_q, shared_output_s, router_logits = (
                 torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
@@ -315,42 +384,54 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                     ),
                 )
             )
-            shared_output, _ = self.shared_experts.down_proj(
-                shared_output_q, x_quant_scales=shared_output_s
-            )
+
+            if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16:  # self.alt_stream is not None and
+                shared_output, final_hidden_states = torch.ops.vllm.streams_breaks(
+                    hidden_states=hidden_states,
+                    shared_output_q=shared_output_q,
+                    shared_output_s=shared_output_s,
+                    router_logits=router_logits,
+                    layer_prefix=self.prefix,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                )
+
+                did_fma = True
         else:
             if self.n_shared_experts is not None:
                 shared_output = self.shared_experts(hidden_states_shared)
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
 
-        if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
-            final_hidden_states = self.experts(hidden_states=hidden_states,
-                                               router_logits=router_logits)
-            final_hidden_states = fused_mul_add(final_hidden_states, self.routed_scaling_factor, shared_output)
-        else:
-            if hidden_states.dtype != torch.float16:
-                final_hidden_states = self.experts(
-                    hidden_states=hidden_states,
-                    router_logits=router_logits) * self.routed_scaling_factor
-            else:
-                # Fix FP16 overflow
-                # See DeepseekV2DecoderLayer for more details.
+        if not did_fma:
+            if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
                 final_hidden_states = self.experts(hidden_states=hidden_states,
                                                    router_logits=router_logits)
-            if shared_output is not None:
+                final_hidden_states = fused_mul_add(final_hidden_states, self.routed_scaling_factor, shared_output)
+            else:
                 if hidden_states.dtype != torch.float16:
-                    final_hidden_states = final_hidden_states + shared_output
+                    final_hidden_states = self.experts(
+                        hidden_states=hidden_states,
+                        router_logits=router_logits) * self.routed_scaling_factor
                 else:
                     # Fix FP16 overflow
                     # See DeepseekV2DecoderLayer for more details.
-                    final_hidden_states = final_hidden_states + shared_output \
-                        * (1. / self.routed_scaling_factor)
+                    final_hidden_states = self.experts(hidden_states=hidden_states,
                                                        router_logits=router_logits)
+                if shared_output is not None:
+                    if hidden_states.dtype != torch.float16:
+                        final_hidden_states = final_hidden_states + shared_output
+                    else:
+                        # Fix FP16 overflow
+                        # See DeepseekV2DecoderLayer for more details.
+                        final_hidden_states = final_hidden_states + shared_output \
+                            * (1. / self.routed_scaling_factor)
 
         if self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
-                    final_hidden_states))
+                    final_hidden_states
+                )
+            )
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -721,6 +802,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         enable_eplb: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -765,6 +847,7 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
                 enable_eplb=enable_eplb,
+                alt_stream=alt_stream,
             )
         else:
             self.mlp = DeepseekV2MLP(
@@ -774,6 +857,14 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+
+        if isinstance(self.mlp, DeepseekV2MoE):
+            compilation_config = get_current_vllm_config().compilation_config
+            name = self.mlp.prefix
+            if name in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name in static_forward_context: {name}")
+            compilation_config.static_forward_context[name] = self.mlp
+
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -883,6 +974,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.embed_tokens = PPMissingLayer()
 
+        self.alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: DeepseekV2DecoderLayer(
@@ -892,9 +984,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 enable_eplb=enable_eplb,
+                alt_stream=self.alt_stream,
             ),
             prefix=f"{prefix}.layers")
+
+        self._moe_modules = []
+        for layer in self.layers:
+            mlp = getattr(layer, "mlp", None)
+            if isinstance(mlp, DeepseekV2MoE):
+                self._moe_modules.append(mlp)
 
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
@@ -913,6 +1012,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
+
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds