146 changes: 123 additions & 23 deletions vllm/model_executor/models/deepseek_v2.py
@@ -26,7 +26,6 @@
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
@@ -63,6 +62,7 @@
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.logger import init_logger
from vllm.forward_context import get_forward_context
logger = init_logger(__name__)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
@@ -88,10 +88,64 @@
if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS:
from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16
from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant

import aiter as rocm_aiter
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def streams_breaks(
hidden_states: torch.Tensor,
shared_output_q: torch.Tensor,
shared_output_s: torch.Tensor,
router_logits: torch.Tensor,
layer_prefix: str,
routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:

    """Run the shared-expert down projection on the alternate CUDA stream,
    overlapped with the routed experts, then fuse the two results."""
    # Look up the owning DeepseekV2MoE module registered under this prefix.
    ctx = get_forward_context()
    self = ctx.no_compile_layers[layer_prefix]

    # Order the alternate stream after work already queued on the current one.
    current_stream = torch.cuda.current_stream()
    self.alt_stream.wait_stream(current_stream)

with torch.cuda.stream(self.alt_stream):
shared_output, _ = self.shared_experts.down_proj(
shared_output_q, x_quant_scales=shared_output_s
)

final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)

    # Rejoin the streams before combining the shared and routed outputs.
    current_stream.wait_stream(self.alt_stream)

    final_hidden_states = fused_mul_add(
        final_hidden_states, routed_scaling_factor, shared_output
    )

return shared_output, final_hidden_states

# Fake implementation: only output shapes and dtypes matter (used for tracing).
def streams_breaks_fake(
hidden_states: torch.Tensor,
shared_output_q: torch.Tensor,
shared_output_s: torch.Tensor,
router_logits: torch.Tensor,
layer_prefix: str,
routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
M, H = hidden_states.shape
device = hidden_states.device
dtype = hidden_states.dtype

shared_out = torch.empty((M, H), device=device, dtype=dtype)
final_out = torch.empty((M, H), device=device, dtype=dtype)

return shared_out, final_out

def rocm_aiter_triton_fused_shared_expert_impl(
hidden_states_shared: torch.Tensor,
hidden_states_shared_scale: torch.Tensor,
@@ -133,6 +187,15 @@ def rocm_aiter_triton_fused_shared_expert_fake(
router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device)
return shared_output_q, shared_output_s, router_logits

# Register as torch.ops.vllm.streams_breaks; the fake implementation above
# supplies output shapes/dtypes under torch.compile.
direct_register_custom_op(
op_name="streams_breaks",
op_func=streams_breaks,
mutates_args=[],
fake_impl=streams_breaks_fake,
dispatch_key=current_platform.dispatch_key,
)
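
# For reference, a minimal sketch of how the registered op is then invoked
# (the real call site is in DeepseekV2MoE.forward below; tensor names match
# that method, while the prefix and scaling factor here are made-up example
# values). Under torch.compile the fake implementation supplies output shapes,
# and the real implementation with its stream switches runs eagerly:
#
#   shared_output, final_hidden_states = torch.ops.vllm.streams_breaks(
#       hidden_states=hidden_states,
#       shared_output_q=shared_output_q,
#       shared_output_s=shared_output_s,
#       router_logits=router_logits,
#       layer_prefix="model.layers.3.mlp",  # illustrative example prefix
#       routed_scaling_factor=2.5,          # illustrative example value
#   )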

direct_register_custom_op(
op_name="rocm_aiter_triton_fused_shared_expert",
op_func=rocm_aiter_triton_fused_shared_expert_impl,
@@ -200,6 +263,7 @@ def forward(self, x):
return x



class DeepseekV2MoE(nn.Module):

def __init__(
@@ -208,16 +272,18 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor

self.prefix = prefix
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
self.alt_stream = alt_stream

if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -292,11 +358,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

        # Tracks whether streams_breaks already ran the routed experts and
        # fused in the shared-expert output.
        did_fma = False
if (
self.use_triton_fused_shared_expert
and self.n_shared_experts is not None
):
hidden_states_shared, hidden_states_shared_scale = hidden_states_shared

shared_output_q, shared_output_s, router_logits = (
torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
hidden_states_shared=hidden_states_shared,
@@ -315,42 +384,54 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
),
)
)
shared_output, _ = self.shared_experts.down_proj(
shared_output_q, x_quant_scales=shared_output_s
)

            if (self.alt_stream is not None
                    and VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD
                    and hidden_states.dtype != torch.float16):
                # Overlap the shared-expert down projection with the routed
                # experts on the alternate CUDA stream via the custom op.
                shared_output, final_hidden_states = torch.ops.vllm.streams_breaks(
                    hidden_states=hidden_states,
                    shared_output_q=shared_output_q,
                    shared_output_s=shared_output_s,
                    router_logits=router_logits,
                    layer_prefix=self.prefix,
                    routed_scaling_factor=self.routed_scaling_factor,
                )
                did_fma = True
else:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states_shared)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

        if not did_fma:
            if (VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD
                    and hidden_states.dtype != torch.float16
                    and shared_output is not None):
                final_hidden_states = self.experts(hidden_states=hidden_states,
                                                   router_logits=router_logits)
                final_hidden_states = fused_mul_add(final_hidden_states,
                                                    self.routed_scaling_factor,
                                                    shared_output)
            else:
                if hidden_states.dtype != torch.float16:
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits) * self.routed_scaling_factor
                else:
                    # Fix FP16 overflow
                    # See DeepseekV2DecoderLayer for more details.
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits)
                if shared_output is not None:
                    if hidden_states.dtype != torch.float16:
                        final_hidden_states = final_hidden_states + shared_output
                    else:
                        # Fix FP16 overflow
                        # See DeepseekV2DecoderLayer for more details.
                        final_hidden_states = final_hidden_states + shared_output \
                            * (1. / self.routed_scaling_factor)

        if self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states
                )
            )

return final_hidden_states.view(num_tokens, hidden_dim)

@@ -721,6 +802,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -765,6 +847,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
alt_stream=alt_stream
)
else:
self.mlp = DeepseekV2MLP(
@@ -774,6 +857,14 @@
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)

        # Register the MoE module under its prefix so the streams_breaks
        # custom op can retrieve it from the forward context at runtime.
        if isinstance(self.mlp, DeepseekV2MoE):
            compilation_config = get_current_vllm_config().compilation_config
            name = self.mlp.prefix
            if name in compilation_config.static_forward_context:
                raise ValueError(
                    f"Duplicate layer name in static_forward_context: {name}")
            compilation_config.static_forward_context[name] = self.mlp
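
        # For reference, a minimal sketch of the lookup this registration
        # enables (assumed example prefix; mirrors the lookup performed inside
        # streams_breaks via get_forward_context()):
        #
        #   moe = get_forward_context().no_compile_layers["model.layers.3.mlp"]
        #   # moe is the DeepseekV2MoE instance registered above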

self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -883,6 +974,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.embed_tokens = PPMissingLayer()

        # One alternate CUDA stream shared by all decoder layers (None when
        # CUDA is unavailable).
        self.alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
@@ -892,9 +984,16 @@
cache_config=cache_config,
quant_config=quant_config,
enable_eplb=enable_eplb,
alt_stream=self.alt_stream,
),
prefix=f"{prefix}.layers")

        # Collect the MoE submodules, skipping dense-MLP layers and
        # pipeline-parallel placeholder layers that have no DeepseekV2MoE mlp.
        self._moe_modules = []
        for layer in self.layers:
            mlp = getattr(layer, "mlp", None)
            if isinstance(mlp, DeepseekV2MoE):
                self._moe_modules.append(mlp)

if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
@@ -913,6 +1012,7 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:

if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds