627 changes: 0 additions & 627 deletions lmdeploy/pytorch/backends/cuda/moe.py

This file was deleted.

4 changes: 4 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .blocked_fp8 import TritonFusedMoEBlockedF8Builder # noqa: F401
from .default import TritonFusedMoEBuilder # noqa: F401
from .w8a8 import TritonFusedMoEW8A8Builder # noqa: F401
259 changes: 259 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py
@@ -0,0 +1,259 @@
# Copyright (c) OpenMMLab. All rights reserved.

import contextlib
from typing import Callable, List

import torch
import torch.distributed as dist

from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp

logger = get_logger('lmdeploy')


class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
    """Triton fused moe blocked f8 implementation."""

    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.float16):
        self.num_experts = num_experts
        self.top_k = top_k
        self.renormalize = renormalize
        self.block_size = block_size
        self.out_dtype = out_dtype

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
        num_experts = self.num_experts
        expert_per_rank = (num_experts + world_size - 1) // world_size
        first_expert = rank * expert_per_rank
        last_expert = min(first_expert + expert_per_rank, num_experts)
        return list(range(first_expert, last_expert))
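
    # Worked example (illustrative only, assuming 8 experts on 3 ranks):
    #   expert_per_rank = (8 + 3 - 1) // 3 = 3
    #   rank 0 -> [0, 1, 2], rank 1 -> [3, 4, 5], rank 2 -> [6, 7]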

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """forward."""
        input_size = hidden_states.shape
        hidden_states = hidden_states.flatten(0, -2)
        input_quant, input_scale = quant_fp8(hidden_states, self.block_size, dtype=gate_up_weights.dtype)

        expert_offset = 0
        num_experts = None
        if expert_list is not None and len(expert_list) != self.num_experts:
            expert_offset = expert_list[0]
            num_experts = self.num_experts
        output = fused_moe_blocked_fp8(input_quant,
                                       input_scale,
                                       gate_up_weights,
                                       gate_up_scale,
                                       down_weights,
                                       down_scale,
                                       topk_weights=topk_weights,
                                       topk_ids=topk_ids,
                                       topk=self.top_k,
                                       w1_bias=gate_up_bias,
                                       w2_bias=down_bias,
                                       out_dtype=hidden_states.dtype,
                                       expert_offset=expert_offset,
                                       num_experts=num_experts,
                                       renormalize=self.renormalize,
                                       act_func=act_func)
        output = output.unflatten(0, input_size[:-1])
        return output
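

# A minimal illustrative sketch of what blocked fp8 quantization does conceptually:
# one scale per `block_size`-wide slice of the hidden dim, scaled into the e4m3 range.
# The helper name, the 448.0 constant and the output layout are assumptions for
# illustration, not the actual `quant_fp8` kernel; it assumes the torch build
# provides `torch.float8_e4m3fn` (PyTorch >= 2.1).
def _example_blockwise_fp8_quant(x: torch.Tensor, block_size: int = 128):
    """Quantize the last dim in blocks of `block_size`, returning fp8 data and scales."""
    m, k = x.shape
    assert k % block_size == 0
    blocks = x.float().view(m, k // block_size, block_size)
    # one scale per (row, block), chosen so the block fits the e4m3 max of 448
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    quant = (blocks / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return quant.view(m, k), scale.squeeze(-1)

# Example: x = torch.randn(4, 256, dtype=torch.float16)
#          q, s = _example_blockwise_fp8_quant(x)   # q: (4, 256) fp8, s: (4, 2) scales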


@contextlib.contextmanager
def monk_deep_gemm():
    from dlblas.kernels.fused_moe_v3 import use_deep_gemm
    if use_deep_gemm:
        yield
        return

    # patch deep_gemm
    import deep_gemm
    import dlblas

    from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm
    func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None)
    func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None)
    deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked

    # patch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = True
    dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \
        patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous
    yield

    # unpatch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = False

    # unpatch deep_gemm
    if func0_ is not None:
        deep_gemm.get_col_major_tma_aligned_tensor = func0_
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_
    else:
        del deep_gemm.get_col_major_tma_aligned_tensor
        del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked
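

# A minimal illustrative sketch of the save/patch/restore pattern `monk_deep_gemm`
# applies to `deep_gemm` and `dlblas`, shown here on a single generic module
# attribute. The helper and its try/finally restore are illustrative assumptions,
# not part of the lmdeploy, deep_gemm or dlblas APIs.
@contextlib.contextmanager
def _example_patch_attr(module, name: str, replacement):
    original = getattr(module, name, None)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        # restore the saved attribute, or remove it if it did not exist before
        if original is not None:
            setattr(module, name, original)
        else:
            delattr(module, name)

# Usage: with _example_patch_attr(math, 'sqrt', lambda v: v ** 0.5): ...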


class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):

    def __init__(self,
                 ep_size: int,
                 ep_group: dist.ProcessGroup,
                 top_k: int,
                 num_experts: int,
                 hidden_dim: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.bfloat16,
                 layer_idx: int = 0):
        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.hidden_dim = hidden_dim
        self.block_size = block_size
        self.out_dtype = out_dtype
        self.layer_idx = layer_idx
        try:
            import deep_gemm  # noqa: F401
            self.use_deep_gemm = True
        except ImportError:
            self.use_deep_gemm = False
            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')

        # pre-allocate buffer
        self.fusedmoe_build(True)

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
        if get_dist_manager().current_context().dist_config.enable_eplb:
            from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer
            phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx)
            expert_per_rank = (self.num_experts + world_size - 1) // world_size
            first_expert = rank * expert_per_rank
            last_expert = min(first_expert + expert_per_rank, self.num_experts)
            sliced_phy2log = phy2log[first_expert:last_expert].tolist()
            return sliced_phy2log
        else:
            return super().ep_expert_list(world_size=world_size, rank=rank)

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None,
                **kwargs):
        """forward."""
        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,
                                                                                    topk_ids)

        topk_weights = self.do_renormalize(topk_weights)
        step_ctx = get_step_ctx_manager().current_context()
        low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
        moe = self.fusedmoe_build(low_latency_mode)
        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
                                 down_scale, expert_list)

        out_states = gather_outputs_by_attn_tp(out_states, split_size)
        return out_states

    def do_renormalize(self, topk_weights):
        return _renormalize(topk_weights, self.renormalize)

    def fusedmoe_build(self, low_latency_mode: bool = False):
        from dlblas.layers.moe.ep_moe import build_deepep_moe
        deepep_moe = build_deepep_moe(low_latency_mode,
                                      self.ep_size,
                                      self.ep_group,
                                      self.num_experts,
                                      self.hidden_dim,
                                      self.block_size,
                                      self.top_k,
                                      self.out_dtype,
                                      layer_idx=self.layer_idx,
                                      chunk_size=16 * 1024)

        # patch forward
        _origin_forward = deepep_moe.forward
        _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward

        def _patched_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_forward(*args, **kwargs)
            return out

        def _patched_fusedmoe_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_fusedmoe_forward(*args, **kwargs)
            return out

        deepep_moe.forward = _patched_forward
        deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward

        return deepep_moe
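

# A minimal illustrative sketch of the closure-wrapping idiom `fusedmoe_build` uses:
# rebinding a method so every call runs inside a context manager. `functools` and
# the helper name are illustrative additions, not part of the dlblas API.
import functools

def _example_wrap_with_ctx(fn, ctx_factory):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # enter the context for the duration of each call
        with ctx_factory():
            return fn(*args, **kwargs)
    return wrapper

# Usage: deepep_moe.forward = _example_wrap_with_ctx(deepep_moe.forward, monk_deep_gemm)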


class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
    """Triton fused moe blocked f8 builder."""

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              hidden_dim: int = 1,
              renormalize: bool = False,
              block_size: int = 128,
              ep_size: int = 1,
              ep_group: dist.ProcessGroup = None,
              out_dtype: torch.dtype = torch.float16,
              layer_idx: int = 0,
              custom_gateup_act: bool = False):
        """Build from mlp."""
        if ep_size > 1:
            assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.'
            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
                                               ep_group=ep_group,
                                               top_k=top_k,
                                               num_experts=num_experts,
                                               hidden_dim=hidden_dim,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype,
                                               layer_idx=layer_idx)
        else:
            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
                                               num_experts=num_experts,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype)
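

# A minimal illustrative usage sketch of the builder's dispatch (argument values are
# made up; real callers pass model-specific sizes and, for EP, a process group):
#
#   impl = TritonFusedMoEBlockedF8Builder.build(top_k=6,
#                                               num_experts=64,
#                                               hidden_dim=4096,
#                                               renormalize=True,
#                                               block_size=128,
#                                               ep_size=1,
#                                               out_dtype=torch.bfloat16)
#   # ep_size == 1 -> TritonFusedMoEBlockedF8Impl
#   # ep_size  > 1 -> FusedDeepEpMoEBlockedF8Impl (requires ep_group and DeepEP/dlblas)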