diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py deleted file mode 100644 index d0e9e663f2..0000000000 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ /dev/null @@ -1,627 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import contextlib -from typing import Callable, List - -import torch -import torch.distributed as dist - -from lmdeploy.pytorch.distributed import get_dist_manager -from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8 -from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8 -from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 -from lmdeploy.pytorch.kernels.cuda.ep_moe import (grouped_gemm_triton, silu_and_mul_masked_post_quant_fwd, - silu_and_mul_triton_kernel) -from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize -from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8 -from lmdeploy.pytorch.model_inputs import get_step_ctx_manager -from lmdeploy.pytorch.models.q_modules import QTensor -from lmdeploy.utils import get_logger - -from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder, - FusedMoEW8A8Impl) - -logger = get_logger('lmdeploy') - - -class TritonFusedMoEImpl(FusedMoEImpl): - """Triton fused moe implementation.""" - - def __init__(self, top_k: int, num_experts: int, renormalize: bool = False): - self.num_experts = num_experts - self.top_k = top_k - self.renormalize = renormalize - - def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor): - gate_up_weights = gate_up_weights.transpose(1, 2).contiguous().transpose(1, 2) - down_weights = down_weights.transpose(1, 2).contiguous().transpose(1, 2) - return gate_up_weights, down_weights - - def support_ep(self): - """Support expert parallelism.""" - return True - - def ep_expert_list(self, world_size: int, rank: int): - """Experts list of current rank.""" - num_experts = self.num_experts - expert_per_rank = (num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, num_experts) - return list(range(first_expert, last_expert)) - - def forward(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - gate_up_weights: torch.Tensor, - down_weights: torch.Tensor, - gate_up_bias: torch.Tensor = None, - down_bias: torch.Tensor = None, - expert_list: List[int] = None, - act_func: Callable = None): - """forward.""" - expert_offset = 0 - num_experts = None - if expert_list is not None and len(expert_list) != self.num_experts: - expert_offset = expert_list[0] - num_experts = self.num_experts - return fused_moe(hidden_states, - gate_up_weights, - down_weights, - topk_weights=topk_weights, - topk_ids=topk_ids, - topk=self.top_k, - w1_bias=gate_up_bias, - w2_bias=down_bias, - expert_offset=expert_offset, - num_experts=num_experts, - renormalize=self.renormalize, - act_func=act_func) - - -class TritonFusedMoEBuilder(FusedMoEBuilder): - """Triton fused moe builder.""" - - @staticmethod - def build(top_k: int, num_experts: int, renormalize: bool = False): - """Build from mlp.""" - return TritonFusedMoEImpl(top_k=top_k, num_experts=num_experts, renormalize=renormalize) - - -class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl): - """Triton fused moe w8a8 implementation.""" - - def __init__( - self, - top_k: int, - num_experts: int, - renormalize: bool = False, - out_dtype: torch.dtype = 
torch.float16, - quant_dtype: torch.dtype = torch.int8, - ): - self.num_experts = num_experts - self.top_k = top_k - self.renormalize = renormalize - self.out_dtype = out_dtype - self.quant_dtype = quant_dtype - - def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor, - down_scale: torch.Tensor): - # do not transpose weight for int8/fp8 - return gate_up_weights, down_weights, gate_up_scale, down_scale - - def support_ep(self): - """Support expert parallelism.""" - return True - - def ep_expert_list(self, world_size: int, rank: int): - """Experts list of current rank.""" - num_experts = self.num_experts - expert_per_rank = (num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, num_experts) - return list(range(first_expert, last_expert)) - - def forward(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - gate_up_weights: torch.Tensor, - gate_up_scale: torch.Tensor, - down_weights: torch.Tensor, - down_scale: torch.Tensor, - expert_list: List[int] = None): - """forward.""" - - if isinstance(hidden_states, torch.Tensor): - hidden_states = hidden_states.contiguous() - input_quant, input_scale = per_token_quant_int8(hidden_states, 1e-7, quant_dtype=self.quant_dtype) - else: - assert isinstance(hidden_states, QTensor) - input_quant, input_scale = (hidden_states.tensor, hidden_states.scale) - - expert_offset = 0 - num_experts = None - if expert_list is not None and len(expert_list) != self.num_experts: - expert_offset = expert_list[0] - num_experts = self.num_experts - return fused_moe_w8a8(input_quant, - input_scale, - gate_up_weights, - gate_up_scale, - down_weights, - down_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - topk=self.top_k, - out_dtype=self.out_dtype, - quant_dtype=self.quant_dtype, - expert_offset=expert_offset, - num_experts=num_experts, - renormalize=self.renormalize) - - -class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder): - """Triton fused moe w8a8 builder.""" - - @staticmethod - def build( - top_k: int, - num_experts: int, - renormalize: bool = False, - out_dtype: torch.dtype = torch.float16, - quant_dtype: torch.dtype = torch.int8, - ): - """Build from mlp.""" - return TritonFusedMoEW8A8Impl(top_k=top_k, - num_experts=num_experts, - renormalize=renormalize, - out_dtype=out_dtype, - quant_dtype=quant_dtype) - - -class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl): - """Triton fused moe blocked f8 implementation.""" - - def __init__(self, - top_k: int, - num_experts: int, - renormalize: bool = False, - block_size: int = 128, - out_dtype: torch.dtype = torch.float16): - self.num_experts = num_experts - self.top_k = top_k - self.renormalize = renormalize - self.block_size = block_size - self.out_dtype = out_dtype - - def support_ep(self): - """Support expert parallelism.""" - return True - - def ep_expert_list(self, world_size: int, rank: int): - """Experts list of current rank.""" - num_experts = self.num_experts - expert_per_rank = (num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, num_experts) - return list(range(first_expert, last_expert)) - - def forward(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - gate_up_weights: torch.Tensor, - gate_up_scale: torch.Tensor, - down_weights: torch.Tensor, - down_scale: torch.Tensor, - gate_up_bias: 
torch.Tensor = None, - down_bias: torch.Tensor = None, - expert_list: List[int] = None, - act_func: Callable = None): - """forward.""" - input_size = hidden_states.shape - hidden_states = hidden_states.flatten(0, -2) - input_quant, input_scale = quant_fp8(hidden_states, self.block_size, dtype=gate_up_weights.dtype) - - expert_offset = 0 - num_experts = None - if expert_list is not None and len(expert_list) != self.num_experts: - expert_offset = expert_list[0] - num_experts = self.num_experts - output = fused_moe_blocked_fp8(input_quant, - input_scale, - gate_up_weights, - gate_up_scale, - down_weights, - down_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - topk=self.top_k, - w1_bias=gate_up_bias, - w2_bias=down_bias, - out_dtype=hidden_states.dtype, - expert_offset=expert_offset, - num_experts=num_experts, - renormalize=self.renormalize, - act_func=act_func) - output = output.unflatten(0, input_size[:-1]) - return output - - -class DeepEPExpertsGroupedGEMM: - """MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek- - ai/DeepEP/tree/main)""" - - def __init__( - self, - num_experts: int, - ep_size: int, - block_shape: list[int], - ): - self.num_experts = num_experts - self.ep_size = ep_size - assert self.num_experts % self.ep_size == 0 - self.num_experts_per_partition = self.num_experts // self.ep_size - self.block_shape = block_shape - self.use_fp8_w8a8 = True - - def forward(self, hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor, gate_up_weight: torch.Tensor, - gate_up_scale: torch.Tensor, gate_down_weight: torch.Tensor, gate_down_scale: torch.Tensor): - seg_indptr_cur_rank = torch.cat([ - torch.zeros(1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype), - torch.cumsum(tokens_per_expert, dim=0), - ]) - reorder_topk_ids = torch.repeat_interleave(tokens_per_expert) - weight_indices_cur_rank = torch.arange( - 0, - self.num_experts_per_partition, - device=hidden_states.device, - dtype=torch.int64, - ) - - # GroupGemm-0 - gateup_output = torch.empty( - hidden_states.shape[0], - gate_up_weight.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - if hidden_states.shape[0] > 0: - input, input_scale = quant_fp8(hidden_states, 128, dtype=gate_up_weight.dtype) - gateup_output = grouped_gemm_triton( - a=input, - b=gate_up_weight, - c=gateup_output, - batch_size=self.num_experts_per_partition, - weight_column_major=True, - seg_indptr=seg_indptr_cur_rank, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=self.use_fp8_w8a8, - scale_a=input_scale, - scale_b=gate_up_scale, - block_shape=self.block_shape, - ) - - # Act - down_input = torch.empty( - gateup_output.shape[0], - gateup_output.shape[1] // 2, - device=gateup_output.device, - dtype=hidden_states.dtype, - ) - silu_and_mul_triton_kernel[(gateup_output.shape[0], )]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - None, - 0, - self.num_experts_per_partition - 1, - BLOCK_SIZE=512, - ) - - # GroupGemm-1 - down_output = torch.empty( - down_input.shape[0], - gate_down_weight.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - if down_input.shape[0] > 0: - down_input, down_input_scale = quant_fp8(down_input, 128, dtype=gate_down_weight.dtype) - down_output = grouped_gemm_triton( - a=down_input, - b=gate_down_weight, - c=down_output, - batch_size=self.num_experts_per_partition, - weight_column_major=True, - seg_indptr=seg_indptr_cur_rank, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=self.use_fp8_w8a8, - 
scale_a=down_input_scale, - scale_b=gate_down_scale, - block_shape=self.block_shape, - ) - return down_output - - -class DeepEPExpertsDeepGEMM: - deep_gemm = None - - def __init__(self, num_experts: int, ep_size: int, block_size: int, out_dtype: torch.dtype = torch.bfloat16): - self.num_experts = num_experts - self.ep_size = ep_size - self.num_experts_per_partition = self.num_experts // self.ep_size - self.block_size = block_size - self.use_fp8_w8a8 = True - self.out_dtype = out_dtype - - def forward( - self, - hidden_states_fp8, - gate_up_weight: torch.Tensor, - gate_up_scale: torch.Tensor, - gate_down_weight: torch.Tensor, - gate_down_scale: torch.Tensor, - masked_m: torch.Tensor, - expected_m: int, - ): - - gate_up_weight_fp8 = (gate_up_weight, gate_up_scale) - gate_down_weight_fp8 = (gate_down_weight, gate_down_scale) - assert (hidden_states_fp8[0].size(0) % 4 == 0), f'TMA alignment error: {hidden_states_fp8[0].size(0)}' - num_groups, m, k = hidden_states_fp8[0].size() - n = gate_up_weight.size(1) - expected_m = min(expected_m, m) - gateup_output = torch.empty((num_groups, m, n), device=hidden_states_fp8[0].device, dtype=self.out_dtype) - DeepEPExpertsDeepGEMM.deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(hidden_states_fp8, gate_up_weight_fp8, - gateup_output, masked_m, expected_m) - down_input = torch.empty(( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2, - ), - device=gateup_output.device, - dtype=gate_down_weight.dtype) - - down_input_scale = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2 // self.block_size, - ), - device=gateup_output.device, - dtype=torch.float32, - ) - silu_and_mul_masked_post_quant_fwd( - gateup_output, - down_input, - down_input_scale, - self.block_size, - masked_m, - ) - n = gate_down_weight.size(1) - down_input_fp8 = ( - down_input, - DeepEPExpertsDeepGEMM.deep_gemm.get_col_major_tma_aligned_tensor(down_input_scale), - ) - down_output = torch.empty((num_groups, m, n), device=down_input.device, dtype=self.out_dtype) - DeepEPExpertsDeepGEMM.deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(down_input_fp8, gate_down_weight_fp8, - down_output, masked_m, expected_m) - return down_output - - -@contextlib.contextmanager -def monk_deep_gemm(): - from dlblas.kernels.fused_moe_v3 import use_deep_gemm - if use_deep_gemm: - yield - return - - # patch deep_gemm - import deep_gemm - import dlblas - - from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm - func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None) - func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None) - deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked - - # patch dlblas - dlblas.kernels.fused_moe_v3.use_deep_gemm = True - dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \ - patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous - yield - - # unpatch dlblas - dlblas.kernels.fused_moe_v3.use_deep_gemm = False - - # unpatch deep_gemm - if func0_ is not None: - deep_gemm.get_col_major_tma_aligned_tensor = func0_ - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_ - else: - del deep_gemm.get_col_major_tma_aligned_tensor - del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked - - -class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl): - - def __init__(self, - ep_size: int, - ep_group: 
dist.ProcessGroup, - top_k: int, - num_experts: int, - hidden_dim: int, - renormalize: bool = False, - block_size: int = 128, - out_dtype: torch.dtype = torch.bfloat16, - layer_idx: int = 0): - super().__init__(top_k, num_experts, renormalize, block_size, out_dtype) - self.num_experts = num_experts - self.ep_size = ep_size - self.ep_group = ep_group - self.hidden_dim = hidden_dim - self.block_size = block_size - self.out_dtype = out_dtype - self.layer_idx = layer_idx - try: - import deep_gemm - DeepEPExpertsDeepGEMM.deep_gemm = deep_gemm - self.use_deep_gemm = True - except ImportError: - self.use_deep_gemm = False - logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') - - # pre-allocate buffer - self.fusedmoe_build(True) - - def ep_expert_list(self, world_size: int, rank: int): - """Experts list of current rank.""" - if get_dist_manager().current_context().dist_config.enable_eplb: - from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer - phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx) - expert_per_rank = (self.num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, self.num_experts) - sliced_phy2log = phy2log[first_expert:last_expert].tolist() - return sliced_phy2log - else: - return super().ep_expert_list(world_size=world_size, rank=rank) - - def _split_inputs_by_attn_tp( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - ): - """Split input by attn tp.""" - dist_ctx = get_dist_manager().current_context() - attn_tp = dist_ctx.dist_config.attn_tp - attn_rank = dist_ctx.attn_tp_group.rank - num_states = hidden_states.size(0) - - if attn_tp == 1 or attn_tp > num_states: - return hidden_states, topk_weights, topk_ids, None - - # split size - base = num_states // attn_tp - remain = num_states % attn_tp - split_size = [base + 1] * remain + [base] * (attn_tp - remain) - - # split inputs - hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank] - topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank] - topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank] - - return hidden_states, topk_weights, topk_ids, split_size - - def _gather_outputs_by_attn_tp(self, out_states: torch.Tensor, split_size: List[int]): - """Gather output by attn tp.""" - if split_size is None: - return out_states - - dist_ctx = get_dist_manager().current_context() - gpu_group = dist_ctx.attn_tp_group.gpu_group - new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1])) - new_out_states_list = list(new_out_states.split(split_size, dim=0)) - dist.all_gather(new_out_states_list, out_states, group=gpu_group) - return new_out_states - - def forward(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - gate_up_weights: torch.Tensor, - gate_up_scale: torch.Tensor, - down_weights: torch.Tensor, - down_scale: torch.Tensor, - gate_up_bias: torch.Tensor = None, - down_bias: torch.Tensor = None, - expert_list: List[int] = None, - act_func: Callable = None, - **kwargs): - """forward.""" - hidden_states, topk_weights, topk_ids, split_size = self._split_inputs_by_attn_tp( - hidden_states, topk_weights, topk_ids) - - topk_weights = self.do_renormalize(topk_weights) - step_ctx = get_step_ctx_manager().current_context() - low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm - moe = self.fusedmoe_build(low_latency_mode) - 
out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, - down_scale, expert_list) - - out_states = self._gather_outputs_by_attn_tp(out_states, split_size) - return out_states - - def do_renormalize(self, topk_weights): - return _renormalize(topk_weights, self.renormalize) - - def fusedmoe_build(self, low_latency_mode: bool = False): - from dlblas.layers.moe.ep_moe import build_deepep_moe - deepep_moe = build_deepep_moe(low_latency_mode, - self.ep_size, - self.ep_group, - self.num_experts, - self.hidden_dim, - self.block_size, - self.top_k, - self.out_dtype, - layer_idx=self.layer_idx, - chunk_size=16 * 1024) - - # patch forward - _origin_forward = deepep_moe.forward - _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward - - def _patched_forward(*args, **kwargs): - with monk_deep_gemm(): - out = _origin_forward(*args, **kwargs) - return out - - def _patched_fusedmoe_forward(*args, **kwargs): - with monk_deep_gemm(): - out = _origin_fusedmoe_forward(*args, **kwargs) - return out - - deepep_moe.forward = _patched_forward - deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward - - return deepep_moe - - -class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): - """Triton fused moe blocked f8 builder.""" - - @staticmethod - def build(top_k: int, - num_experts: int, - hidden_dim: int = 1, - renormalize: bool = False, - block_size: int = 128, - ep_size: int = 1, - ep_group: dist.ProcessGroup = None, - out_dtype: torch.dtype = torch.float16, - layer_idx: int = 0, - custom_gateup_act: bool = False): - """Build from mlp.""" - if ep_size > 1: - assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.' - return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size, - ep_group=ep_group, - top_k=top_k, - num_experts=num_experts, - hidden_dim=hidden_dim, - renormalize=renormalize, - block_size=block_size, - out_dtype=out_dtype, - layer_idx=layer_idx) - else: - return TritonFusedMoEBlockedF8Impl(top_k=top_k, - num_experts=num_experts, - renormalize=renormalize, - block_size=block_size, - out_dtype=out_dtype) diff --git a/lmdeploy/pytorch/backends/cuda/moe/__init__.py b/lmdeploy/pytorch/backends/cuda/moe/__init__.py new file mode 100644 index 0000000000..882a9ec54c --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/moe/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .blocked_fp8 import TritonFusedMoEBlockedF8Builder # noqa: F401 +from .default import TritonFusedMoEBuilder # noqa: F401 +from .w8a8 import TritonFusedMoEW8A8Builder # noqa: F401 diff --git a/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py b/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py new file mode 100644 index 0000000000..d1bd0caf41 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py @@ -0,0 +1,259 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +import contextlib +from typing import Callable, List + +import torch +import torch.distributed as dist + +from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl +from lmdeploy.pytorch.distributed import get_dist_manager +from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8 +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 +from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager +from lmdeploy.utils import get_logger + +from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp + +logger = get_logger('lmdeploy') + + +class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl): + """Triton fused moe blocked f8 implementation.""" + + def __init__(self, + top_k: int, + num_experts: int, + renormalize: bool = False, + block_size: int = 128, + out_dtype: torch.dtype = torch.float16): + self.num_experts = num_experts + self.top_k = top_k + self.renormalize = renormalize + self.block_size = block_size + self.out_dtype = out_dtype + + def ep_expert_list(self, world_size: int, rank: int): + """Experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + gate_up_bias: torch.Tensor = None, + down_bias: torch.Tensor = None, + expert_list: List[int] = None, + act_func: Callable = None): + """forward.""" + input_size = hidden_states.shape + hidden_states = hidden_states.flatten(0, -2) + input_quant, input_scale = quant_fp8(hidden_states, self.block_size, dtype=gate_up_weights.dtype) + + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts + output = fused_moe_blocked_fp8(input_quant, + input_scale, + gate_up_weights, + gate_up_scale, + down_weights, + down_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + w1_bias=gate_up_bias, + w2_bias=down_bias, + out_dtype=hidden_states.dtype, + expert_offset=expert_offset, + num_experts=num_experts, + renormalize=self.renormalize, + act_func=act_func) + output = output.unflatten(0, input_size[:-1]) + return output + + +@contextlib.contextmanager +def monk_deep_gemm(): + from dlblas.kernels.fused_moe_v3 import use_deep_gemm + if use_deep_gemm: + yield + return + + # patch deep_gemm + import deep_gemm + import dlblas + + from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm + func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None) + func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None) + deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked + + # patch dlblas + dlblas.kernels.fused_moe_v3.use_deep_gemm = True + dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \ + patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous + yield + + # unpatch dlblas + dlblas.kernels.fused_moe_v3.use_deep_gemm = 
False + + # unpatch deep_gemm + if func0_ is not None: + deep_gemm.get_col_major_tma_aligned_tensor = func0_ + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_ + else: + del deep_gemm.get_col_major_tma_aligned_tensor + del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked + + +class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl): + + def __init__(self, + ep_size: int, + ep_group: dist.ProcessGroup, + top_k: int, + num_experts: int, + hidden_dim: int, + renormalize: bool = False, + block_size: int = 128, + out_dtype: torch.dtype = torch.bfloat16, + layer_idx: int = 0): + super().__init__(top_k, num_experts, renormalize, block_size, out_dtype) + self.num_experts = num_experts + self.ep_size = ep_size + self.ep_group = ep_group + self.hidden_dim = hidden_dim + self.block_size = block_size + self.out_dtype = out_dtype + self.layer_idx = layer_idx + try: + import deep_gemm # noqa: F401 + self.use_deep_gemm = True + except ImportError: + self.use_deep_gemm = False + logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') + + # pre-allocate buffer + self.fusedmoe_build(True) + + def ep_expert_list(self, world_size: int, rank: int): + """Experts list of current rank.""" + if get_dist_manager().current_context().dist_config.enable_eplb: + from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer + phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx) + expert_per_rank = (self.num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, self.num_experts) + sliced_phy2log = phy2log[first_expert:last_expert].tolist() + return sliced_phy2log + else: + return super().ep_expert_list(world_size=world_size, rank=rank) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + gate_up_bias: torch.Tensor = None, + down_bias: torch.Tensor = None, + expert_list: List[int] = None, + act_func: Callable = None, + **kwargs): + """forward.""" + hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights, + topk_ids) + + topk_weights = self.do_renormalize(topk_weights) + step_ctx = get_step_ctx_manager().current_context() + low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm + moe = self.fusedmoe_build(low_latency_mode) + out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, + down_scale, expert_list) + + out_states = gather_outputs_by_attn_tp(out_states, split_size) + return out_states + + def do_renormalize(self, topk_weights): + return _renormalize(topk_weights, self.renormalize) + + def fusedmoe_build(self, low_latency_mode: bool = False): + from dlblas.layers.moe.ep_moe import build_deepep_moe + deepep_moe = build_deepep_moe(low_latency_mode, + self.ep_size, + self.ep_group, + self.num_experts, + self.hidden_dim, + self.block_size, + self.top_k, + self.out_dtype, + layer_idx=self.layer_idx, + chunk_size=16 * 1024) + + # patch forward + _origin_forward = deepep_moe.forward + _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward + + def _patched_forward(*args, **kwargs): + with monk_deep_gemm(): + out = _origin_forward(*args, **kwargs) + return out + + def _patched_fusedmoe_forward(*args, **kwargs): + with monk_deep_gemm(): + out = _origin_fusedmoe_forward(*args, 
**kwargs) + return out + + deepep_moe.forward = _patched_forward + deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward + + return deepep_moe + + +class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): + """Triton fused moe blocked f8 builder.""" + + @staticmethod + def build(top_k: int, + num_experts: int, + hidden_dim: int = 1, + renormalize: bool = False, + block_size: int = 128, + ep_size: int = 1, + ep_group: dist.ProcessGroup = None, + out_dtype: torch.dtype = torch.float16, + layer_idx: int = 0, + custom_gateup_act: bool = False): + """Build from mlp.""" + if ep_size > 1: + assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.' + return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size, + ep_group=ep_group, + top_k=top_k, + num_experts=num_experts, + hidden_dim=hidden_dim, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype, + layer_idx=layer_idx) + else: + return TritonFusedMoEBlockedF8Impl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype) diff --git a/lmdeploy/pytorch/backends/cuda/moe/default.py b/lmdeploy/pytorch/backends/cuda/moe/default.py new file mode 100644 index 0000000000..2421df0f04 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/moe/default.py @@ -0,0 +1,473 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Callable, List, Optional + +import torch + +import lmdeploy.pytorch.distributed as dist +from lmdeploy.pytorch.backends.moe import FusedMoEBuilder, FusedMoEImpl +from lmdeploy.pytorch.distributed import get_dist_manager +from lmdeploy.pytorch.kernels.cuda import fused_moe +from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager +from lmdeploy.utils import get_logger + +from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp + +logger = get_logger('lmdeploy') + + +class TritonFusedMoEImpl(FusedMoEImpl): + """Triton fused moe implementation.""" + + def __init__(self, top_k: int, num_experts: int, renormalize: bool = False): + self.num_experts = num_experts + self.top_k = top_k + self.renormalize = renormalize + + def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor): + gate_up_weights = gate_up_weights.transpose(1, 2).contiguous().transpose(1, 2) + down_weights = down_weights.transpose(1, 2).contiguous().transpose(1, 2) + return gate_up_weights, down_weights + + def ep_expert_list(self, world_size: int, rank: int): + """Experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + gate_up_bias: torch.Tensor = None, + down_bias: torch.Tensor = None, + expert_list: List[int] = None, + act_func: Callable = None): + """forward.""" + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts + return fused_moe(hidden_states, + gate_up_weights, + down_weights, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + w1_bias=gate_up_bias, + w2_bias=down_bias, + 
expert_offset=expert_offset, + num_experts=num_experts, + renormalize=self.renormalize, + act_func=act_func) + + +# modify from dlblas: https://github.com/DeepLink-org/DLBlas +class FusedMoENormal: + + def __init__( + self, + ep_size: int, + ep_group: dist.ProcessGroup, + num_experts: int, + hidden_dim: int, + layer_index: int = 0, + top_k: int = 8, + chunk_size: Optional[int] = 32 * 1024, + out_dtype: torch.dtype = torch.bfloat16, + ): + from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherNormal + self.layer_index = layer_index + self.top_k = top_k + self.num_experts = num_experts + self.num_local_experts = num_experts // ep_size + self.out_dtype = out_dtype + self.token_dispatcher = DeepEPTokenDispatcherNormal( + group=ep_group, + num_experts=num_experts, + num_local_experts=self.num_local_experts, + hidden_size=hidden_dim, + params_dtype=out_dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + up_weights: torch.Tensor, + down_weights: torch.Tensor, + expert_list: List[int] = None, + ): + """forward.""" + from lmdeploy.pytorch.kernels.cuda.fused_moe_ep import fused_moe_v3 + x, recv_topk_ids, recv_topk_weights, recv_tokens_per_expert = self.token_dispatcher.dispatch( + hidden_states, + topk_ids, + topk_weights, + expert_list, + ) + topk_ids, topk_weights = None, None + out_states = fused_moe_v3(x, recv_topk_ids, recv_topk_weights, up_weights, down_weights, recv_tokens_per_expert) + out_states = self.token_dispatcher.combine(out_states) + return out_states + + def capture(self): + return self.token_dispatcher.buffer_normal.capture() + + def wait(self, event): + self.token_dispatcher.release() + event.current_stream_wait() + + def dispatch_async(self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: Optional[int] = None, + previous_event=None, + async_finish=True): + return self.token_dispatcher.dispatch_normal_async(x, topk_idx, topk_weights, num_experts, previous_event, + async_finish) + + def combine_async(self, x: torch.Tensor, handle: tuple, previous_event=None, async_finish=True): + return self.token_dispatcher.combine_normal_async(x, handle, previous_event, async_finish) + + def release(self): + return self.token_dispatcher.release() + + def fusedmoe_forward(self, state, up_weight, down_weight): + from lmdeploy.pytorch.kernels.cuda.fused_moe_ep import fused_moe_v3 + return fused_moe_v3(state['recv_hidden_states'], state['recv_topk_idx'], state['recv_topk_weights'], up_weight, + down_weight, state['recv_tokens_per_expert']) + + +def _disposible_tensor(tensor): + from dlblas.utils.utils import DisposibleTensor + if isinstance(tensor, torch.Tensor): + tensor = DisposibleTensor(tensor) + else: + tensor = [DisposibleTensor(x) for x in tensor] + return tensor + + +def dispatch_ll( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + use_fp8: bool = True, +): + """Dispatch low latency.""" + if num_experts is not None and self.num_experts is not None: + assert self.num_experts == num_experts + topk_idx = topk_idx.to(torch.int64) + expected_m = (hidden_states.shape[0] * self.get_buffer().group_size * topk_idx.shape[1] + + num_experts) // num_experts + + ( + packed_recv_hidden, + masked_m, + self.handle, + event, + hook, + ) = self.get_buffer().low_latency_dispatch( + hidden_states, + topk_idx, + self.num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8=use_fp8, + async_finish=not 
self.return_recv_hook, + return_recv_hook=self.return_recv_hook, + ) + hook() if self.return_recv_hook else event.current_stream_wait() + packed_recv_hidden = _disposible_tensor(packed_recv_hidden) + return ( + packed_recv_hidden, + topk_idx, + topk_weights, + masked_m, + expected_m, + ) + + +def dispatch_async_ll( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + num_experts: Optional[int] = None, + use_fp8: bool = True, + async_finish: bool = True, +): + assert topk_idx.dtype == torch.int64 + if num_experts is not None and self.num_experts is not None: + assert self.num_experts == num_experts + ( + recv_hidden_states, + recv_expert_count, + handle, + event, + hook, + ) = self.get_buffer().low_latency_dispatch( + hidden_states, + topk_idx, + self.num_max_dispatch_tokens_per_rank, + num_experts=self.num_experts, + use_fp8=use_fp8, + async_finish=async_finish, + return_recv_hook=not async_finish, + ) + recv_hidden_states = _disposible_tensor(recv_hidden_states) + return recv_hidden_states, recv_expert_count, handle, event, hook + + +class FusedMoELowLatency: + + def __init__( + self, + ep_size: int, + ep_group: dist.ProcessGroup, + num_experts: int, + hidden_dim: int, + layer_index: int, + out_dtype: torch.dtype = torch.bfloat16, + ): + from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherLowLatency + self.num_experts = num_experts + self.layer_index = layer_index + self.out_dtype = out_dtype + self.token_dispatcher = DeepEPTokenDispatcherLowLatency( + group=ep_group, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + hidden_size=hidden_dim, + params_dtype=out_dtype, + ) + + def experts( + self, + hidden_states: torch.Tensor, + gate_up_weight: torch.Tensor, + gate_down_weight: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + ): + from dlblas.utils.utils import DisposibleTensor + + from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul + from lmdeploy.pytorch.third_party.deep_gemm import m_grouped_bf16_gemm_nt_masked + num_groups, m, _ = hidden_states.shape + n = gate_up_weight.size(1) + expected_m = min(expected_m, m) + gateup_output = gate_up_weight.new_empty((num_groups, m, n)) + m_grouped_bf16_gemm_nt_masked(DisposibleTensor.maybe_unwrap(hidden_states), gate_up_weight, gateup_output, + masked_m, expected_m) + DisposibleTensor.maybe_dispose(hidden_states) + down_input = silu_and_mul(gateup_output.flatten(0, -2)) + down_input = down_input.view( + gateup_output.shape[0], + gateup_output.shape[1], + gateup_output.shape[2] // 2, + ) + del gateup_output + n = gate_down_weight.size(1) + down_output = down_input.new_empty((num_groups, m, n)) + m_grouped_bf16_gemm_nt_masked(down_input, gate_down_weight, down_output, masked_m, expected_m) + return down_output + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + up_weights: torch.Tensor, + down_weights: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + recv_hidden_states, topk_idx, topk_weights, masked_m, expected_m = dispatch_ll( + self.token_dispatcher, + hidden_states, + topk_ids, + topk_weights, + self.num_experts, + use_fp8=False, + ) + hidden_states = None + out_states = self.experts(recv_hidden_states, up_weights, down_weights, masked_m, expected_m) + out_states = self.token_dispatcher.combine(out_states, topk_idx, topk_weights) + return out_states + + def wait(self, event): + event.current_stream_wait() + + def dispatch_async( + self, + hidden_states: torch.Tensor, + topk_idx: 
torch.Tensor, + num_experts: Optional[int] = None, + use_fp8: bool = False, + async_finish: bool = True, + ): + return dispatch_async_ll(self.token_dispatcher, hidden_states, topk_idx, num_experts, use_fp8, async_finish) + + def combine_async( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: tuple, + async_finish: bool, + ): + return self.token_dispatcher.combine_async(hidden_states, topk_idx, topk_weights, handle, async_finish) + + def fusedmoe_forward(self, state, up_weight, down_weight): + recv_hidden_states = state['recv_hidden_states'] + masked_m = state['recv_expert_count'] + hidden_shape = state['raw_hidden_shape'] + topk_idx = state['topk_idx'] + expected_m = (hidden_shape[0] * self.token_dispatcher.buffer_low_latency.group_size * topk_idx.shape[1] + + self.token_dispatcher.num_experts) // self.token_dispatcher.num_experts + return self.experts(recv_hidden_states, up_weight, down_weight, masked_m, expected_m) + + +def build_deepep_moe( + low_latency_mode: bool, + ep_size: int, + ep_group: dist.ProcessGroup, + num_experts: int, + hidden_dim: int, + top_k: int, + layer_idx: int = 0, + chunk_size: Optional[int] = 32 * 1024, + out_dtype: torch.dtype = torch.bfloat16, +): + if low_latency_mode: + return FusedMoELowLatency(ep_size=ep_size, + ep_group=ep_group, + num_experts=num_experts, + hidden_dim=hidden_dim, + layer_index=layer_idx, + out_dtype=out_dtype) + else: + return FusedMoENormal(ep_size=ep_size, + ep_group=ep_group, + num_experts=num_experts, + hidden_dim=hidden_dim, + layer_index=layer_idx, + top_k=top_k, + chunk_size=chunk_size, + out_dtype=out_dtype) + + +class FusedMoEEPImpl(TritonFusedMoEImpl): + """Fused moe implementation.""" + + def __init__( + self, + ep_size: int, + ep_group: dist.ProcessGroup, + top_k: int, + num_experts: int, + hidden_dim: int, + renormalize: bool = False, + layer_idx: int = 0, + out_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__(top_k, num_experts, renormalize) + self.num_experts = num_experts + self.ep_size = ep_size + self.ep_group = ep_group + self.hidden_dim = hidden_dim + self.layer_idx = layer_idx + self.out_dtype = out_dtype + + try: + import deep_gemm # noqa: F401 + except ImportError: + logger.exception('DeepGEMM is required for DeepEP MoE implementation.') + + # pre-allocate buffer + self.fusedmoe_build(True) + + def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor): + # gate_up_weights = gate_up_weights.transpose(1, 2).contiguous().transpose(1, 2) + # down_weights = down_weights.transpose(1, 2).contiguous().transpose(1, 2) + return gate_up_weights, down_weights + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + gate_up_bias: torch.Tensor = None, + down_bias: torch.Tensor = None, + expert_list: List[int] = None, + act_func: Callable = None): + """forward.""" + assert act_func is None, 'Activation function is not supported in DeepEP MoE.' 
+ hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights, + topk_ids) + + topk_weights = self.do_renormalize(topk_weights) + step_ctx = get_step_ctx_manager().current_context() + low_latency_mode = step_ctx.is_decoding + moe = self.fusedmoe_build(low_latency_mode) + out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, down_weights, expert_list) + + out_states = gather_outputs_by_attn_tp(out_states, split_size) + return out_states + + def ep_expert_list(self, world_size: int, rank: int): + """Experts list of current rank.""" + if get_dist_manager().current_context().dist_config.enable_eplb: + raise NotImplementedError('float16/bfloat16 enable_eplb is not Implemented.') + else: + return super().ep_expert_list(world_size=world_size, rank=rank) + + def do_renormalize(self, topk_weights): + return _renormalize(topk_weights, self.renormalize) + + def fusedmoe_build(self, low_latency_mode: bool = False): + deepep_moe = build_deepep_moe(low_latency_mode, + self.ep_size, + self.ep_group, + self.num_experts, + self.hidden_dim, + self.top_k, + layer_idx=self.layer_idx, + chunk_size=16 * 1024, + out_dtype=self.out_dtype) + return deepep_moe + + +class TritonFusedMoEBuilder(FusedMoEBuilder): + """Triton fused moe builder.""" + + @staticmethod + def build( + top_k: int, + num_experts: int, + renormalize: bool = False, + hidden_dim: int = 1, + ep_size: int = 1, + ep_group: dist.ProcessGroup = None, + layer_idx: int = 0, + out_dtype: torch.dtype = torch.bfloat16, + ): + """Build from mlp.""" + if ep_size > 1: + return FusedMoEEPImpl(ep_size=ep_size, + ep_group=ep_group, + top_k=top_k, + num_experts=num_experts, + hidden_dim=hidden_dim, + renormalize=renormalize, + layer_idx=layer_idx, + out_dtype=out_dtype) + return TritonFusedMoEImpl(top_k=top_k, num_experts=num_experts, renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/cuda/moe/ep_utils.py b/lmdeploy/pytorch/backends/cuda/moe/ep_utils.py new file mode 100644 index 0000000000..f4c596a99c --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/moe/ep_utils.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List + +import torch +from torch import distributed as dist + +from lmdeploy.pytorch.distributed import get_dist_manager + + +def split_inputs_by_attn_tp( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +): + """Split input by attn tp.""" + dist_ctx = get_dist_manager().current_context() + attn_tp = dist_ctx.dist_config.attn_tp + attn_rank = dist_ctx.attn_tp_group.rank + num_states = hidden_states.size(0) + + if attn_tp == 1 or attn_tp > num_states: + return hidden_states, topk_weights, topk_ids, None + + # split size + base = num_states // attn_tp + remain = num_states % attn_tp + split_size = [base + 1] * remain + [base] * (attn_tp - remain) + + # split inputs + hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank] + topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank] + topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank] + + return hidden_states, topk_weights, topk_ids, split_size + + +def gather_outputs_by_attn_tp(out_states: torch.Tensor, split_size: List[int]): + """Gather output by attn tp.""" + if split_size is None: + return out_states + + dist_ctx = get_dist_manager().current_context() + gpu_group = dist_ctx.attn_tp_group.gpu_group + new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1])) + new_out_states_list = list(new_out_states.split(split_size, dim=0)) + dist.all_gather(new_out_states_list, out_states, group=gpu_group) + return new_out_states diff --git a/lmdeploy/pytorch/backends/cuda/moe/w8a8.py b/lmdeploy/pytorch/backends/cuda/moe/w8a8.py new file mode 100644 index 0000000000..19358f9751 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/moe/w8a8.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import List + +import torch + +from lmdeploy.pytorch.backends.moe import FusedMoEW8A8Builder, FusedMoEW8A8Impl +from lmdeploy.pytorch.kernels.cuda import fused_moe_w8a8 +from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8 +from lmdeploy.pytorch.models.q_modules import QTensor +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + + +class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl): + """Triton fused moe w8a8 implementation.""" + + def __init__( + self, + top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16, + quant_dtype: torch.dtype = torch.int8, + ): + self.num_experts = num_experts + self.top_k = top_k + self.renormalize = renormalize + self.out_dtype = out_dtype + self.quant_dtype = quant_dtype + + def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + # do not transpose weight for int8/fp8 + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + + if isinstance(hidden_states, torch.Tensor): + hidden_states = hidden_states.contiguous() + input_quant, input_scale = per_token_quant_int8(hidden_states, 1e-7, quant_dtype=self.quant_dtype) + else: + assert isinstance(hidden_states, QTensor) + input_quant, input_scale = (hidden_states.tensor, hidden_states.scale) + + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts + return fused_moe_w8a8(input_quant, + input_scale, + gate_up_weights, + gate_up_scale, + down_weights, + down_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + out_dtype=self.out_dtype, + quant_dtype=self.quant_dtype, + expert_offset=expert_offset, + num_experts=num_experts, + renormalize=self.renormalize) + + +class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder): + """Triton fused moe w8a8 builder.""" + + @staticmethod + def build( + top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16, + quant_dtype: torch.dtype = torch.int8, + ): + """Build from mlp.""" + return TritonFusedMoEW8A8Impl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + out_dtype=out_dtype, + quant_dtype=quant_dtype) diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py index 99bc54f29c..ae10e29b47 100644 --- a/lmdeploy/pytorch/backends/dlinfer/moe.py +++ b/lmdeploy/pytorch/backends/dlinfer/moe.py @@ -67,6 +67,13 @@ class DlinferFusedMoEBuilder(FusedMoEBuilder): """Dlinfer fused moe builder.""" @staticmethod - def build(top_k: int, num_experts: int, renormalize: bool = False): + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + hidden_dim: int = 1, + ep_size: int = 1, + ep_group: torch.distributed.ProcessGroup = None, + layer_idx: int = 0, + out_dtype: torch.dtype = torch.bfloat16): """Build from mlp.""" return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py index c12eadcf6e..ca0bfbdd29 100644 --- a/lmdeploy/pytorch/backends/moe.py +++ 
b/lmdeploy/pytorch/backends/moe.py @@ -39,10 +39,6 @@ def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tens """Update weights.""" return gate_up_weights, down_weights - def support_ep(self): - """Support expert parallelism.""" - return False - def ep_expert_list(self, world_size: int, rank: int): """Experts list of current rank.""" raise NotImplementedError('Not Implemented.') @@ -67,7 +63,14 @@ class FusedMoEBuilder(ABC): @staticmethod @abstractmethod - def build(top_k: int, num_experts: int, renormalize: bool = False): + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + hidden_dim: int = 1, + ep_size: int = 1, + ep_group: dist.ProcessGroup = None, + layer_idx: int = 0, + out_dtype: torch.dtype = torch.bfloat16): """Build from mlp.""" raise NotImplementedError @@ -80,10 +83,6 @@ def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tens """Update weights.""" return gate_up_weights, down_weights, gate_up_scale, down_scale - def support_ep(self): - """Support expert parallelism.""" - return False - def ep_expert_list(self, world_size: int, rank: int): """Experts list of current rank.""" raise NotImplementedError('Not Implemented.') @@ -125,10 +124,6 @@ def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tens """Update weights.""" return gate_up_weights, down_weights, gate_up_scale, down_scale - def support_ep(self): - """Support expert parallelism.""" - return False - def ep_expert_list(self, world_size: int, rank: int): """Experts list of current rank.""" raise NotImplementedError('Not Implemented.') diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index f36aeced3e..bc75d99627 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -126,6 +126,9 @@ def _patched_get_env( # we don't need to read this, it would be passed to ray workers # If Ray is launched from outside, it may fail to access the environment variables. os.getenv('DEEPEP_MAX_BATCH_SIZE', None) + os.getenv('DEEPEP_MAX_TOKENS_PER_RANK', None) + os.getenv('DEEPEP_ENABLE_MNNVL', None) + os.getenv('DEEPEP_MODE', 'auto') # deepgemm os.getenv('DG_JIT_DEBUG', '0') diff --git a/lmdeploy/pytorch/kernels/cuda/ep_moe.py b/lmdeploy/pytorch/kernels/cuda/ep_moe.py deleted file mode 100644 index ad620ead1c..0000000000 --- a/lmdeploy/pytorch/kernels/cuda/ep_moe.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-# modify from sglang -from typing import List, Optional - -import torch -import triton -import triton.language as tl - - -@triton.jit -def silu_and_mul_triton_kernel( - gateup_output, - down_input, - hidden_size, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - BLOCK_SIZE: tl.constexpr, -): - InDtype = gateup_output.dtype.element_ty - OutDtype = down_input.dtype.element_ty - - half_hidden_size = hidden_size // 2 - - pid = tl.program_id(0) - expert_id = tl.load(reorder_topk_ids + pid) - if expert_id >= start_expert_id and expert_id <= end_expert_id: - gateup_output_ptr = gateup_output + pid * hidden_size - gate_output_ptr = gateup_output_ptr - up_output_ptr = gateup_output_ptr + half_hidden_size - down_input_ptr = down_input + pid * half_hidden_size - - if scales is not None: - scale = tl.load(scales + expert_id - start_expert_id) - scale = (1 / scale).to(InDtype) - else: - scale = 1 - - for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) - mask = offset < half_hidden_size - - gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) - up_output = tl.load(up_output_ptr + offset, mask=mask) - - # silu & mul & quantize - gate_output = gate_output * tl.sigmoid(gate_output) - gate_output = gate_output.to(InDtype) - - silu_mul_output = gate_output * up_output * scale - silu_mul_output = silu_mul_output.to(OutDtype) - tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) - - -@triton.jit -def compute_m_range( - pid, - batch_size, - seg_indptr, - weight_indices, - m_num_tiles_indptr, - BLOCK_SIZE_M: tl.constexpr, -): - idx = 0 - for bs in range(batch_size): - tiles = tl.load(m_num_tiles_indptr + bs) - if pid >= tiles: - idx = bs - - idx_start = tl.load(m_num_tiles_indptr + idx) - - m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M - m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M) - expert_id = tl.load(weight_indices + idx) - return m_range_start, m_range_end, expert_id - - -@triton.jit -def grouped_gemm_triton_kernel( - a, - b, - c, - batch_size, - N, - K, - seg_indptr, - weight_indices, - m_num_tiles_indptr, - scale_a, - scale_b, - use_fp8_w8a8: tl.constexpr, - group_n: tl.constexpr, - group_k: tl.constexpr, - a_stride_0: tl.constexpr, - b_stride_0: tl.constexpr, - b_stride_1: tl.constexpr, - as_stride_0: tl.constexpr, - as_stride_1: tl.constexpr, - bs_stride_0: tl.constexpr, - bs_stride_2: tl.constexpr, - bs_stride_1: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - c_dtype = c.dtype.element_ty - - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - total_m_block = tl.load(m_num_tiles_indptr + batch_size) - if pid_m >= total_m_block: - return - - m_range_start, m_range_end, expert_id = compute_m_range(pid_m, batch_size, seg_indptr, weight_indices, - m_num_tiles_indptr, BLOCK_SIZE_M) - if m_range_end - m_range_start == 0: - return - - n_range_start = pid_n * BLOCK_SIZE_N - n_range_end = min(n_range_start + BLOCK_SIZE_N, N) - - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - - offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) - offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptr = a + 
(m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] - b_ptr = b + ((expert_id * b_stride_0) + (n_range_start + offs_bn[:, None]) * b_stride_1 + offs_k[None, :]) - - if group_k > 0 and group_n > 0: - a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0 - offs_bsn = (n_range_start + offs_bn) // group_n - b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1 - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a_tile = tl.load(a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0) - b_tile = tl.load(b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0) - - if group_k > 0 and group_n > 0: - k_start = k * BLOCK_SIZE_K - offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1) - b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2) - accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :] - else: - accumulator = tl.dot(a_tile, b_tile.T, accumulator) - a_ptr += BLOCK_SIZE_K - b_ptr += BLOCK_SIZE_K - - if use_fp8_w8a8 and not (group_k > 0 and group_n > 0): - scale_a_value = tl.load(scale_a + expert_id) - scale_b_value = tl.load(scale_b + expert_id) - accumulator *= scale_a_value * scale_b_value - - c_tile = accumulator.to(c_dtype) - - offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) - c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] - c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) - tl.store(c_ptr, c_tile, mask=c_mask) - - -@triton.jit -def compute_m_num_tiles_indptr(m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): - for bs in range(batch_size): - m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) - cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) - pre_num_tiles = tl.load(m_num_tiles_indptr + bs) - tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) - - -def grouped_gemm_triton( - a: torch.Tensor, - b: torch.Tensor, - c: torch.Tensor, - batch_size: int, - weight_column_major: bool, - seg_indptr: Optional[torch.Tensor] = None, - weight_indices: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, - scale_a: torch.Tensor = None, - scale_b: torch.Tensor = None, - block_shape: Optional[List[int]] = None, -): - assert weight_column_major - if use_fp8_w8a8 and block_shape is None: - assert scale_a is not None and scale_b is not None - - if block_shape is not None: - assert len(block_shape) == 2 - block_n, block_k = block_shape[0], block_shape[1] - assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] - assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] - assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1] - - # TODO: adjust config or tune kernel - # Reduce block size to prevent L40 shared memory overflow. 
- config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - } - - m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64) - compute_m_num_tiles_indptr[(1, )](m_num_tiles_indptr, seg_indptr, batch_size, config['BLOCK_SIZE_M']) - - def grid(META): - return ( - triton.cdiv(a.size(0), META['BLOCK_SIZE_M']) + batch_size, - triton.cdiv(b.size(1), META['BLOCK_SIZE_N']), - ) - - grouped_gemm_triton_kernel[grid]( - a, - b, - c, - batch_size, - b.size(1), - b.size(2), - seg_indptr, - weight_indices, - m_num_tiles_indptr, - scale_a, - scale_b, - use_fp8_w8a8, - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - a.stride(0), - b.stride(0), - b.stride(1), - scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0, - scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0, - scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0, - scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0, - scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0, - **config, - ) - return c - - -@triton.jit -def _silu_and_mul_post_quant_kernel( - input_ptr, - stride_input_0, - stride_input_1, - stride_input_2, - output_ptr, - stride_output_0, - stride_output_1, - stride_output_2, - output_scale_ptr, - stride_output_scale_0, - stride_output_scale_1, - stride_output_scale_2, - masked_m_ptr, - size_n, - fp8_max, - fp8_min, - BLOCK_N: tl.constexpr, - NUM_STAGE: tl.constexpr, -): - expert_id = tl.program_id(2) - token_id = tl.program_id(1) - hidden_dim_block_index = tl.program_id(0) - - block_num_per_expert = tl.num_programs(1) - - token_num_cur_expert = tl.load(masked_m_ptr + expert_id) - - stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) - stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) - stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) - stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) - - offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N) - input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d - output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d - output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 + - hidden_dim_block_index * stride_output_scale_2) - - for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): - gate = tl.load( - input_ptr_offs + token_index * stride_input_1, - mask=offs_in_d < size_n, - other=0.0, - ).to(tl.float32) - up = tl.load( - input_ptr_offs + token_index * stride_input_1 + size_n, - mask=offs_in_d < size_n, - other=0.0, - ) - gate = gate / (1 + tl.exp(-gate)) - gate = gate.to(input_ptr.dtype.element_ty) - gate_up = up * gate - _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) - output_s = _absmax / fp8_max - output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(output_ptr.dtype.element_ty) - tl.store( - output_ptr_offs + token_index * stride_output_1, - output_q, - mask=offs_in_d < size_n, - ) - tl.store( - output_scale_offs + token_index * stride_output_scale_1, - output_s, - ) - - -def silu_and_mul_masked_post_quant_fwd( - input: torch.Tensor, - output: torch.Tensor, - output_scale: torch.Tensor, - quant_group_size: int, - masked_m: torch.Tensor, -): - assert input.is_contiguous() - assert output.dtype == torch.float8_e4m3fn - assert output.is_contiguous() - assert len(input.shape) == 3 - assert input.shape[0] == masked_m.shape[0] - assert input.shape[-1] % 2 == 0 - 
size_n = input.shape[-1] // 2 - assert size_n % quant_group_size == 0 - expert_num = len(masked_m) - if expert_num < 4: - BLOCK_NUM_PER_EXPERT = 64 - else: - BLOCK_NUM_PER_EXPERT = 32 - BLOCK_N = quant_group_size - num_warps = 1 - NUM_STAGES = 6 - hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N) - assert BLOCK_N % quant_group_size == 0 - grid = ( - hidden_dim_split_block_num, - BLOCK_NUM_PER_EXPERT, - expert_num, - ) - finfo = torch.finfo(torch.float8_e4m3fn) - fp8_max = finfo.max - fp8_min = -fp8_max - _silu_and_mul_post_quant_kernel[grid]( - input, - *input.stride(), - output, - *output.stride(), - output_scale, - *output_scale.stride(), - masked_m, - size_n, - fp8_max, - fp8_min, - BLOCK_N=BLOCK_N, - NUM_STAGE=NUM_STAGES, - num_warps=num_warps, - ) diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py b/lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py new file mode 100644 index 0000000000..b7213c84ba --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py @@ -0,0 +1,266 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from dlblas: https://github.com/DeepLink-org/DLBlas +from typing import List, Optional + +import torch +import triton +import triton.language as tl + +from .activation import silu_and_mul + + +@triton.jit +def _fwd_kernel_ep_scatter_1( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts: tl.constexpr, + BLOCK_E: tl.constexpr, + BLOCK_EXPERT_NUM: tl.constexpr, +): + cur_expert = tl.program_id(0) + offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM) + tokens_per_expert = tl.load( + num_recv_tokens_per_expert + offset_cumsum, + mask=offset_cumsum < num_experts, + other=0, + ) + cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert + tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) + cur_expert_start = tl.load(expert_start_loc + cur_expert) + cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) + m_indices_start_ptr = m_indices + cur_expert_start + off_expert = tl.arange(0, BLOCK_E) + for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + tl.store( + m_indices_start_ptr + start_m + off_expert, + cur_expert, + ) + + +@triton.jit +def _fwd_kernel_ep_scatter_2( + total_token_num, + expert_start_loc, + recv_x, + recv_x_stride0, + recv_x_stride1, + recv_topk, + recv_topk_stride0, + recv_topk_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + output_index, + output_index_stride0, + output_index_stride1, + topk_num: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + HIDDEN_SIZE_PAD: tl.constexpr, +): + start_token_id = tl.program_id(0) + grid_num = tl.num_programs(0) + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE + for token_id in range(start_token_id, total_token_num, grid_num): + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + for topk_index in tl.range(0, topk_num, 1, num_stages=4): + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) + if expert_id >= 0: + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) + dest_token_index = dest_token_index.to(tl.int64) + tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index) + output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0 + tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) + + +# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py +def ep_scatter( + recv_x: 
torch.Tensor, + recv_topk: torch.Tensor, + num_recv_tokens_per_expert: torch.Tensor, + expert_start_loc: torch.Tensor, + output_tensor: torch.Tensor, + m_indices: torch.Tensor, + output_index: torch.Tensor, +): + BLOCK_E = 128 # token num of per expert is aligned to 128 + num_warps = 8 + num_experts = num_recv_tokens_per_expert.shape[0] + hidden_size = recv_x.shape[1] + # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts) + grid = num_experts + assert m_indices.shape[0] % BLOCK_E == 0 + _fwd_kernel_ep_scatter_1[(grid, )]( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts=num_experts, + num_warps=num_warps, + BLOCK_E=BLOCK_E, + BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), + ) + grid = min(recv_topk.shape[0], 1024 * 8) + _fwd_kernel_ep_scatter_2[(grid, )]( + recv_topk.shape[0], + expert_start_loc, + recv_x, + recv_x.stride(0), + recv_x.stride(1), + recv_topk, + recv_topk.stride(0), + recv_topk.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_index, + output_index.stride(0), + output_index.stride(1), + topk_num=recv_topk.shape[1], + num_warps=num_warps, + HIDDEN_SIZE=hidden_size, + HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), + ) + return + + +@triton.jit +def _fwd_kernel_ep_gather( + total_token_num, + input_tensor, + input_tensor_stride0, + input_tensor_stride1, + recv_topk_ids, + recv_topk_ids_stride0, + recv_topk_ids_stride1, + recv_topk_weight, + recv_topk_weight_stride0, + recv_topk_weight_stride1, + input_index, + input_index_stride0, + input_index_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + topk_num: tl.constexpr, + BLOCK_D: tl.constexpr, +): + cur_block = tl.program_id(0) + start_cur_token = tl.program_id(1) + grid_num = tl.num_programs(1) + for cur_token in range(start_cur_token, total_token_num, grid_num): + off_d = tl.arange(0, BLOCK_D) + accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) + for topk_index in range(0, topk_num): + expert_id = tl.load(recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index) + if expert_id >= 0: + source_token_index = tl.load(input_index + cur_token * input_index_stride0 + topk_index) + acc_weight = tl.load(recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index) + tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + cur_block * BLOCK_D + off_d) + accumulator += tmp.to(tl.float32) * acc_weight + tl.store( + output_tensor + cur_token * output_tensor_stride0 + cur_block * BLOCK_D + off_d, + accumulator.to(output_tensor.dtype.element_ty), + ) + + +@torch.no_grad() +def ep_gather( + input_tensor: torch.Tensor, + recv_topk_ids: torch.Tensor, + recv_topk_weight: torch.Tensor, + input_index: torch.Tensor, + output_tensor: torch.Tensor, +): + BLOCK_D = 1024 # block size of quantization + num_warps = 2 + num_tokens = output_tensor.shape[0] + hidden_size = input_tensor.shape[1] + assert hidden_size % BLOCK_D == 0 + grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024)) + _fwd_kernel_ep_gather[grid]( + num_tokens, + input_tensor, + input_tensor.stride(0), + input_tensor.stride(1), + recv_topk_ids, + recv_topk_ids.stride(0), + recv_topk_ids.stride(1), + recv_topk_weight, + recv_topk_weight.stride(0), + recv_topk_weight.stride(1), + input_index, + input_index.stride(0), + input_index.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + topk_num=recv_topk_ids.shape[1], + num_warps=num_warps, + BLOCK_D=BLOCK_D, + ) + return + + +def 
_deepgemm_grouped_bf16_nt_contiguous( + x: torch.Tensor, + w: torch.Tensor, + out: torch.Tensor, + m_indices: torch.Tensor, +): + from lmdeploy.pytorch.third_party import deep_gemm + return deep_gemm.m_grouped_bf16_gemm_nt_contiguous(x, w, out, m_indices) + + +def fused_moe_v3( + hidden_states: torch.Tensor, + topk_idx, + topk_weights, + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, + num_recv_tokens_per_expert: Optional[List[int]], +): + if num_recv_tokens_per_expert is None: + return hidden_states + all_tokens = sum(num_recv_tokens_per_expert) + if all_tokens <= 0: + return hidden_states + M, K = hidden_states.size() + N = w13_weight.size(1) + gather_out = torch.empty_like(hidden_states) + input_tensor = hidden_states.new_empty((all_tokens, K)) + m_indices = hidden_states.new_empty(all_tokens, dtype=torch.int32) + output_index = torch.empty_like(topk_idx) + num_recv_tokens_per_expert_gpu = torch.tensor( + num_recv_tokens_per_expert, + dtype=torch.int32, + pin_memory=True, + device='cpu', + ).cuda(non_blocking=True) + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) + ep_scatter( + hidden_states, + topk_idx, + num_recv_tokens_per_expert_gpu, + expert_start_loc, + input_tensor, + m_indices, + output_index, + ) + del hidden_states + gateup_output = gather_out.new_empty((all_tokens, N)) + _deepgemm_grouped_bf16_nt_contiguous(input_tensor, w13_weight, gateup_output, m_indices) + down_input = gateup_output.new_empty(( + all_tokens, + N // 2, + )) + down_input = silu_and_mul(gateup_output.view(-1, N), down_input) + down_output = gather_out.new_empty((all_tokens, K)) + _deepgemm_grouped_bf16_nt_contiguous( + down_input, + w2_weight, + down_output, + m_indices, + ) + ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out) + return gather_out diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py deleted file mode 100644 index 3330050207..0000000000 --- a/lmdeploy/pytorch/nn/moe.py +++ /dev/null @@ -1,1093 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
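The fused_moe_v3 path above (ep_scatter, grouped gate_up GEMM, silu_and_mul, grouped down GEMM, ep_gather) computes, per token, the weighted sum of its top-k expert MLPs. A minimal dense sketch of that computation is given below; the function name and per-token loops are illustrative only and ignore the expert-aligned layout and the bf16 grouped kernels.

import torch
import torch.nn.functional as F


def moe_reference(hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                  w13_weight: torch.Tensor, w2_weight: torch.Tensor) -> torch.Tensor:
    """Dense sketch: w13_weight is (num_experts, N, K), w2_weight is (num_experts, K, N // 2)."""
    # accumulate in fp32, as ep_gather does, then cast back to the activation dtype
    out = torch.zeros(hidden_states.shape, dtype=torch.float32, device=hidden_states.device)
    for t in range(hidden_states.size(0)):
        for j in range(topk_idx.size(1)):
            e = int(topk_idx[t, j])
            if e < 0:  # slots marked invalid by the dispatcher are skipped
                continue
            gate, up = (w13_weight[e] @ hidden_states[t]).chunk(2)
            act = F.silu(gate) * up  # same math as silu_and_mul
            out[t] += topk_weights[t, j].float() * (w2_weight[e] @ act).float()
    return out.to(hidden_states.dtype)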
-from collections import defaultdict -from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional - -import torch -from torch import nn - -import lmdeploy.pytorch.distributed as dist -from lmdeploy.pytorch.config import TPMode -from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank -from lmdeploy.pytorch.model_inputs import get_step_ctx_manager - -from ..backends import OpType, get_backend -from .quant_utils import quant_blocked_fp8 -from .utils import div_up - - -class MoeType(Enum): - """Batch ecex type.""" - Default = auto() - DSSyncDecode = auto() - DSAsyncDecode = auto() - DSSyncPrefill = auto() - DSAsyncPrefill = auto() - - -class SoftmaxTopK(nn.Module): - """Softmax topk.""" - - def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1): - super().__init__() - self.top_k = top_k - impl_builder = get_backend().get_layer_impl_builder(OpType.SoftmaxTopK) - self.impl = impl_builder.build(top_k, dim, n_groups=n_groups) - - def forward(self, x: torch.Tensor): - """forward.""" - return self.impl.forward(x) - - -def create_mlp_weights(hidden_dim: int, ffn_dim: int, num_experts: int, dtype: torch.dtype, device: torch.device): - """Create weights.""" - gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim), dtype=dtype, device=device) - down_weights = torch.empty((num_experts, hidden_dim, ffn_dim), dtype=dtype, device=device) - return gate_up_weights, down_weights - - -def _update_args(hidden_dim: int, ffn_dim: int): - """Update args.""" - world_size, _ = get_tp_world_rank('moe') - assert ffn_dim % world_size == 0 - ffn_dim = ffn_dim // world_size - return hidden_dim, ffn_dim - - -def _split_size(size: int, world_size: int, align: int): - size = size // align - assert size >= world_size - base = size // world_size - remain = size % world_size - split_size = [base + 1] * remain + [base] * (world_size - remain) - split_size = [s * align for s in split_size] - return split_size - - -class MoEForwardDPTP: - - def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192): - """MoE forward dp tp.""" - self.gemm_func = gemm_func - self.dist_ctx = get_dist_manager().current_context() - self.dist_config = self.dist_ctx.dist_config - self.tp = self.dist_config.moe_tp - self.attn_tp = self.dist_config.attn_tp - - tp_group = self.dist_ctx.moe_tp_group - self.rank = tp_group.rank - self.gather_rank = self.rank // self.attn_tp - self.gather_group = tp_group.gpu_gather_group - self.tp_group = tp_group.gpu_group - self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2 - - def all_gather(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - tp_sizes: List[int]): - """All gather.""" - hidden_states, h0 = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True) - topk_weights, h1 = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True) - topk_ids, h2 = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True) - return hidden_states, topk_weights, topk_ids, (h0, h1, h2) - - def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]): - """Reduce scatter.""" - hidden_states_list = list(hidden_states.split(tp_sizes, -2)) - cur_out_states = hidden_states_list[self.gather_rank] - out_states.copy_(cur_out_states) - hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)] - hidden_states_list[self.rank] 
= out_states - handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True) - return out_states, handle - - def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - output_states: torch.Tensor, tp_sizes: List[int], handles: List[dist.Work]): - """Gemm and reduce scatter.""" - for handle in handles: - handle.wait() - cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids) - return self.reduce_scatter(cur_out, output_states, tp_sizes) - - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor): - """forward.""" - - def __slice_tensor(tensor: torch.Tensor, slice_size: int): - """Slice tensor.""" - cur_tensor = tensor[:slice_size] - tensor = tensor[slice_size:] - return cur_tensor, tensor - - def __slice_and_gather(): - """Slice and gather.""" - nonlocal hidden_states, topk_weights, topk_ids, tp_sizes, output_states - cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round) - tp_sizes -= cur_tp_sizes - cur_tp_sizes = cur_tp_sizes.tolist() - - slice_size = cur_tp_sizes[self.gather_rank] - cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size) - cur_topk_weights, topk_weights = __slice_tensor(topk_weights, slice_size) - cur_topk_ids, topk_ids = __slice_tensor(topk_ids, slice_size) - cur_output, output_states = __slice_tensor(output_states, slice_size) - - # all gather - cur_hidden_states, cur_topk_weights, cur_topk_ids, handles = self.all_gather( - cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes) - return dict(hidden_states=cur_hidden_states, - topk_weights=cur_topk_weights, - topk_ids=cur_topk_ids, - output_states=cur_output, - handles=handles, - tp_sizes=cur_tp_sizes) - - step_ctx = get_step_ctx_manager().current_context() - tp_sizes = step_ctx.dp_meta.moe_tp_sizes - tp_sizes = torch.tensor(tp_sizes) - max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round) - - output_states = torch.empty_like(hidden_states) - return_states = output_states - - # pre - cur_inputs = __slice_and_gather() - - out_handles = [] - # main loop - while tp_sizes.sum() > 0: - next_inputs = __slice_and_gather() - _, handle = self._gemm_and_reduce_scatter(**cur_inputs) - out_handles.append(handle) - cur_inputs = next_inputs - - # post - _, handle = self._gemm_and_reduce_scatter(**cur_inputs) - out_handles.append(handle) - for handle in out_handles: - handle.wait() - return return_states - - -class LinearWeights(nn.Module): - """Fused moe linear weights.""" - - def __init__(self, - num_experts: int, - in_features: int, - out_features: int, - weight_type: str, - dtype: torch.dtype, - device: torch.device, - bias: bool = False, - expert_list: List[int] = None, - ep: bool = False): - super().__init__() - weight = torch.empty((num_experts, out_features, in_features), dtype=dtype, device=device) - weight = torch.nn.Parameter(weight, requires_grad=False) - self.register_parameter('weight', weight) - - if bias: - bias = torch.empty((num_experts, out_features), dtype=dtype, device=device) - bias = torch.nn.Parameter(bias, requires_grad=False) - self.register_parameter('bias', bias) - else: - self.bias = None - - self.ep = ep - self.expert_list = expert_list - self.weight_type = weight_type - self.half_out = out_features // 2 - - self.setup_weight_loader() - - def setup_weight_loader(self): - """Setup weight loader.""" - if self.ep: - self.expert_map = defaultdict(list) - for idx, eid in enumerate(self.expert_list): - 
self.expert_map[eid].append(idx) - self.weight.weight_loader = self.weight_loader_ep - else: - self.weight.weight_loader = self.weight_loader_tp - - if self.bias is not None: - self.bias.weight_loader = self.weight_loader_ep if self.ep else self.weight_loader_tp - - def update_weight(self, weight: torch.Tensor): - """Update weight.""" - weight_loader = self.weight.weight_loader - weight = torch.nn.Parameter(weight, requires_grad=False) - weight.weight_loader = weight_loader - self.register_parameter('weight', weight) - - def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): - """Weight loader.""" - world_size, rank = get_tp_world_rank('moe') - if shard_id == 'gate': - param_data = param.data[expert_id, :self.half_out] - weight = loaded_weight.chunk(world_size, dim=0)[rank] - elif shard_id == 'up': - param_data = param.data[expert_id, self.half_out:] - weight = loaded_weight.chunk(world_size, dim=0)[rank] - elif shard_id == 'down': - param_data = param.data[expert_id] - # weight is not contiguous, chunk and copy in cpu is slow - weight = loaded_weight.to(param_data.device) - if weight.dim() > 1: - weight = weight.chunk(world_size, dim=1)[rank] - elif weight.dim() == 1 and rank != 0: - # bias with rank>0 should be 0 - weight = torch.zeros_like(weight) - else: - raise RuntimeError(f'Unknown shard_id: {shard_id}') - param_data.copy_(weight) - - def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): - """Weight loader.""" - expert_list = self.expert_list - if expert_id not in expert_list: - return - - expert_map = self.expert_map - param_ids = expert_map[expert_id] - for param_id in param_ids: - if shard_id == 'gate': - param_data = param.data[param_id, :self.half_out] - elif shard_id == 'up': - param_data = param.data[param_id, self.half_out:] - elif shard_id == 'down': - param_data = param.data[param_id] - else: - raise RuntimeError(f'Unknown shard_id: {shard_id}') - param_data.copy_(loaded_weight) - - -def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, group: Optional[dist.ProcessGroup] = None): - dist_config = get_dist_manager().current_config() - tp = dist_config.moe_tp - if tp == 1: - return hidden_states, topk_weights, topk_ids - - tp_mode = dist_config.moe_tp_mode - if tp_mode == TPMode.DEFAULT: - return hidden_states, topk_weights, topk_ids - elif tp_mode == TPMode.DP_TP: - step_ctx = get_step_ctx_manager().current_context() - dp_meta = step_ctx.dp_meta - tp_sizes = dp_meta.moe_tp_sizes - hidden_states = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=group) - topk_weights = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=group) - topk_ids = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=group) - else: - raise RuntimeError('Not supported.') - - return hidden_states, topk_weights, topk_ids - - -def _moe_reduce(ret, rank: int, tp_mode: TPMode, group: Optional[dist.ProcessGroup] = None): - dist_config = get_dist_manager().current_config() - if dist_config.moe_tp == 1: - return ret - - if tp_mode == TPMode.DEFAULT: - dist.all_reduce(ret, group=group) - return ret - elif tp_mode == TPMode.DP_TP: - step_ctx = get_step_ctx_manager().current_context() - dp_meta = step_ctx.dp_meta - tp_size = dp_meta.moe_tp_sizes - ret = dist.reduce_scatter_by_tp_sizes(ret, rank, tp_size, group=group) - return ret - else: - raise RuntimeError('Not supported.') - - -class FusedMoE(nn.Module): - """Fused moe.""" - - def __init__(self, - hidden_dim: int, - ffn_dim: 
int, - num_experts: int, - top_k: int, - bias: bool = False, - renormalize: bool = False, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - all_reduce: bool = True, - enable_ep: bool = False, - act_func: Callable = None): - super().__init__() - if device is None: - device = torch.device('cpu') - if dtype is None: - dtype = torch.float16 - self.init_tp_args(all_reduce, enable_ep) - - impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) - self.impl = impl_builder.build(top_k, num_experts, renormalize) - - enable_ep = enable_ep and self.impl.support_ep() - if enable_ep: - world_size, rank = get_tp_world_rank('moe') - expert_list = self.impl.ep_expert_list(world_size, rank) - num_experts = len(expert_list) - else: - hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) - expert_list = None - self.expert_list = expert_list - self.gate_up = LinearWeights(num_experts, - hidden_dim, - ffn_dim * 2, - weight_type='gate_up', - dtype=dtype, - device=device, - bias=bias, - expert_list=expert_list, - ep=enable_ep) - self.down = LinearWeights( - num_experts, - ffn_dim, - hidden_dim, - weight_type='down', - dtype=dtype, - device=device, - bias=bias, - expert_list=expert_list, - ep=enable_ep, - ) - - self.hidden_dim = hidden_dim - self.ffn_dim = ffn_dim - self.num_experts = num_experts - self.dtype = dtype - self.device = device - self.enable_ep = enable_ep - self.act_func = act_func - - def init_tp_args(self, all_reduce: bool, enable_ep: bool): - """Init tp args.""" - tp, tp_rank = get_tp_world_rank('moe') - dist_ctx = get_dist_manager().current_context() - dist_cfg = dist_ctx.dist_config - _, tp_mode = dist_cfg.get_tp_by_layer('moe') - tp = 1 if enable_ep else tp - tp_rank = 0 if enable_ep else tp_rank - all_reduce = all_reduce if tp > 1 else False - all_reduce = False if enable_ep else all_reduce - - self.tp = tp - self.tp_rank = tp_rank - self.tp_mode = tp_mode - self.all_reduce = all_reduce - self.tp_group = dist_ctx.moe_tp_group.gpu_group - self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group - - if self.tp > 1 and self.tp_mode == TPMode.DP_TP: - - def __gemm_func(hidden_states, topk_weights, topk_ids): - return self.gemm( - dict( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_idx=topk_ids, - moe_type=MoeType.Default, - ))['hidden_states'] - - self.forward_dptp = MoEForwardDPTP(__gemm_func) - - def update_weights(self): - """Update weights.""" - gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight) - self.gate_up.update_weight(gate_up_weights) - self.down.update_weight(down_weights) - - def gemm(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Gemm.""" - hidden_states = inputs['hidden_states'] - topk_weights = inputs['topk_weights'] - topk_ids = inputs['topk_idx'] - - ret = self.impl.forward(hidden_states, - topk_weights, - topk_ids, - self.gate_up.weight, - self.down.weight, - self.gate_up.bias, - self.down.bias, - self.expert_list, - act_func=self.act_func) - return dict(hidden_states=ret) - - def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): - hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, - topk_weights, - topk_ids, - group=self.gather_group) - - ret = self.impl.forward(hidden_states, - topk_weights, - topk_ids, - self.gate_up.weight, - self.down.weight, - self.gate_up.bias, - self.down.bias, - self.expert_list, - act_func=self.act_func) - if self.all_reduce: - ret = _moe_reduce(ret, 
rank=self.tp_rank, tp_mode=self.tp_mode, group=self.tp_group) - return ret - - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): - """forward.""" - if self.tp > 1 and self.tp_mode == TPMode.DP_TP: - return self.forward_dptp.forward(hidden_states, topk_weights, topk_ids) - return self.forward_default(hidden_states, topk_weights, topk_ids) - - -class LinearWeightsW8A8(LinearWeights): - """Fused moe linear w8a8 weights.""" - - def __init__(self, - num_experts: int, - in_features: int, - out_features: int, - weight_type: str, - device: torch.device, - expert_list: List[int] = None, - ep: bool = False, - quant_dtype: torch.dtype = torch.int8): - super().__init__( - num_experts=num_experts, - in_features=in_features, - out_features=out_features, - weight_type=weight_type, - dtype=quant_dtype, - device=device, - expert_list=expert_list, - ep=ep, - ) - scale = torch.empty((num_experts, out_features, 1), dtype=torch.float32, device=device) - scale = torch.nn.Parameter(scale, requires_grad=False) - self.register_parameter('scale', scale) - - if self.ep: - self.scale.weight_loader = self.weight_loader_ep - else: - self.scale.weight_loader = self.weight_loader_scale_tp - - def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): - """Update weight.""" - super().update_weight(weight=weight) - weight_loader = self.scale.weight_loader - scale = torch.nn.Parameter(scale, requires_grad=False) - scale.weight_loader = weight_loader - self.register_parameter('scale', scale) - - def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - """Weight loader scale tp.""" - world_size, rank = get_tp_world_rank('moe') - if shard_id == 'gate': - param_data = param.data[expert_id, :self.half_out] - weight = loaded_weight.chunk(world_size, dim=0)[rank] - elif shard_id == 'up': - param_data = param.data[expert_id, self.half_out:] - weight = loaded_weight.chunk(world_size, dim=0)[rank] - elif shard_id == 'down': - param_data = param.data[expert_id] - weight = loaded_weight - else: - raise RuntimeError(f'Unknown shard_id: {shard_id}') - weight = weight.to(param.dtype) - param_data.copy_(weight) - - -class FusedMoEW8A8(nn.Module): - """Fused moe w8a8.""" - - def __init__(self, - hidden_dim: int, - ffn_dim: int, - num_experts: int, - top_k: int, - renormalize: bool = False, - dtype: Optional[torch.dtype] = None, - quant_dtype: Optional[torch.dtype] = torch.int8, - device: Optional[torch.device] = None, - all_reduce: bool = True, - enable_ep: bool = False): - super().__init__() - - if device is None: - device = torch.device('cpu') - dtype = torch.float16 if dtype is None else dtype - - impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEW8A8) - self.impl = impl_builder.build(top_k, num_experts, renormalize, dtype, quant_dtype=quant_dtype) - - enable_ep = enable_ep and self.impl.support_ep() - if enable_ep: - world_size, rank = get_tp_world_rank('moe') - expert_list = self.impl.ep_expert_list(world_size, rank) - num_experts = len(expert_list) - else: - hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) - expert_list = None - self.expert_list = expert_list - - self.gate_up = LinearWeightsW8A8(num_experts, - hidden_dim, - ffn_dim * 2, - weight_type='gate_up', - device=device, - expert_list=expert_list, - ep=enable_ep, - quant_dtype=quant_dtype) - self.down = LinearWeightsW8A8(num_experts, - ffn_dim, - hidden_dim, - weight_type='down', - device=device, - expert_list=expert_list, - 
ep=enable_ep, - quant_dtype=quant_dtype) - - self.hidden_dim = hidden_dim - self.ffn_dim = ffn_dim - self.num_experts = num_experts - self.dtype = dtype - self.device = device - world_size, _ = get_tp_world_rank('moe') - if world_size == 1: - all_reduce = False - self.all_reduce = all_reduce - - def update_weights(self): - """Update weights.""" - (gate_up_weights, down_weights, gate_up_scale, - down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.scale, - self.down.scale) - self.gate_up.update_weight(gate_up_weights, gate_up_scale) - self.down.update_weight(down_weights, down_scale) - - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): - ret = self.impl.forward(hidden_states, topk_weights, topk_ids, self.gate_up.weight, self.gate_up.scale, - self.down.weight, self.down.scale, self.expert_list) - if self.all_reduce: - dist.all_reduce(ret) - return ret - - -class LinearWeightsBlockedF8(LinearWeights): - """Fused moe linear blocked fp8 weights.""" - - def __init__(self, - num_experts: int, - in_features: int, - out_features: int, - weight_type: str, - block_size: int, - dtype: torch.dtype, - device: torch.device, - bias: bool = False, - expert_list: List[int] = None, - ep: bool = False): - super().__init__( - num_experts=num_experts, - in_features=in_features, - out_features=out_features, - weight_type=weight_type, - dtype=dtype, - device=device, - bias=bias, - expert_list=expert_list, - ep=ep, - ) - self.block_size = block_size - weight_scale_inv = torch.empty((num_experts, div_up(out_features, block_size), div_up(in_features, block_size)), - dtype=torch.float32, - device=device) - weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) - self.register_parameter('weight_scale_inv', weight_scale_inv) - - if self.ep: - self.weight._base_weight_loader = self.weight.weight_loader - self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep - else: - self.weight._base_weight_loader = self.weight_loader_tp_blocked_fp8 - self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp - self.weight.weight_loader = self.weight_loader_with_quant - - def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor): - """Update weight.""" - super().update_weight(weight=weight) - weight_loader = self.weight_scale_inv.weight_loader - weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) - weight_scale_inv.weight_loader = weight_loader - self.register_parameter('weight_scale_inv', weight_scale_inv) - - def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - expert_list = self.expert_list - if expert_id not in expert_list: - return - expert_ids = self.expert_map[expert_id] - for expert_id in expert_ids: - self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) - - def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank: int, align: int): - """Chunk with align.""" - split_size = _split_size(weight.size(dim), world_size, align) - return weight.split(split_size, dim=dim)[rank] - - def weight_loader_tp_blocked_fp8(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - """Weight loader.""" - world_size, rank = get_tp_world_rank('moe') - if shard_id == 'gate': - param_data = param.data[expert_id, :self.half_out] - weight = self._chunk_weight_tp(loaded_weight, - dim=0, - world_size=world_size, - rank=rank, - 
align=self.block_size) - elif shard_id == 'up': - param_data = param.data[expert_id, self.half_out:] - weight = self._chunk_weight_tp(loaded_weight, - dim=0, - world_size=world_size, - rank=rank, - align=self.block_size) - elif shard_id == 'down': - param_data = param.data[expert_id] - # weight is not contiguous, chunk and copy in cpu is slow - weight = loaded_weight.to(param_data.device) - if weight.dim() > 1: - weight = self._chunk_weight_tp(weight, dim=1, world_size=world_size, rank=rank, align=self.block_size) - elif weight.dim() == 1 and rank != 0: - # bias with rank>0 should be 0 - weight = torch.zeros_like(weight) - else: - raise RuntimeError(f'Unknown shard_id: {shard_id}') - param_data.copy_(weight) - - def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - """Weight loader scale tp.""" - world_size, rank = get_tp_world_rank('moe') - block_size = self.block_size - half_out = self.half_out // block_size - if shard_id == 'gate': - param_data = param.data[expert_id, :half_out] - weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) - elif shard_id == 'up': - param_data = param.data[expert_id, half_out:] - weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) - elif shard_id == 'down': - param_data = param.data[expert_id] - loaded_weight = loaded_weight.to(param_data.device) - weight = self._chunk_weight_tp(loaded_weight, dim=1, world_size=world_size, rank=rank, align=1) - else: - raise RuntimeError(f'Unknown shard_id: {shard_id}') - param_data.copy_(weight) - - def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - """Weight load with quant.""" - if loaded_weight.dtype != param.dtype: - # quant loaded weight - quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device), param.dtype, self.block_size) - self.weight._base_weight_loader(self.weight, quanted_weight, expert_id, shard_id) - self.weight_scale_inv.weight_loader(self.weight_scale_inv, scaling, expert_id, shard_id) - else: - return self.weight._base_weight_loader(param, loaded_weight, expert_id, shard_id) - - -class FusedMoEBlockedF8(nn.Module): - """Fused moe blocked f8.""" - - def __init__(self, - hidden_dim: int, - ffn_dim: int, - num_experts: int, - top_k: int, - bias: bool = False, - renormalize: bool = False, - fp8_dtype: torch.dtype = torch.float8_e4m3fn, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - all_reduce: bool = True, - enable_ep: bool = False, - layer_idx: int = 0, - act_func: Callable = None): - super().__init__() - if device is None: - device = torch.device('cpu') - dtype = torch.float16 if dtype is None else dtype - self.block_size = 128 - self.init_tp_args(all_reduce, enable_ep) - dist_ctx = get_dist_manager().current_context() - self.ep_size, rank = get_ep_world_rank() - impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8) - self.impl = impl_builder.build(top_k, - num_experts, - hidden_dim, - renormalize, - block_size=self.block_size, - ep_size=self.ep_size, - ep_group=dist_ctx.ep_gpu_group, - out_dtype=dtype, - layer_idx=layer_idx, - custom_gateup_act=act_func is not None) - - if self.ep_size > 1: - expert_list = self.impl.ep_expert_list(self.ep_size, rank) - num_experts = len(expert_list) - else: - hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim, align=self.block_size) - expert_list = None - 
self.expert_list = expert_list - - self.gate_up = LinearWeightsBlockedF8(num_experts, - hidden_dim, - ffn_dim * 2, - weight_type='gate_up', - block_size=self.block_size, - dtype=fp8_dtype, - device=device, - bias=bias, - expert_list=expert_list, - ep=self.ep_size > 1) - self.down = LinearWeightsBlockedF8( - num_experts, - ffn_dim, - hidden_dim, - weight_type='down', - block_size=self.block_size, - dtype=fp8_dtype, - device=device, - bias=bias, - expert_list=expert_list, - ep=self.ep_size > 1, - ) - - self.hidden_dim = hidden_dim - self.ffn_dim = ffn_dim - self.num_experts = num_experts - self.dtype = dtype - self.device = device - self.act_func = act_func - - @staticmethod - def _update_args(hidden_dim: int, ffn_dim: int, align: int): - """Update args.""" - world_size, rank = get_tp_world_rank('moe') - split_size = _split_size(ffn_dim, world_size, align) - ffn_dim = split_size[rank] - return hidden_dim, ffn_dim - - def init_tp_args(self, all_reduce: bool, enable_ep: bool): - """Init tp args.""" - tp, tp_rank = get_tp_world_rank('moe') - dist_ctx = get_dist_manager().current_context() - dist_cfg = dist_ctx.dist_config - _, tp_mode = dist_cfg.get_tp_by_layer('moe') - tp = 1 if enable_ep else tp - tp_rank = 0 if enable_ep else tp_rank - all_reduce = all_reduce if tp > 1 else False - all_reduce = False if enable_ep else all_reduce - - self.tp = tp - self.tp_rank = tp_rank - self.tp_mode = tp_mode - self.all_reduce = all_reduce - self.tp_group = dist_ctx.moe_tp_group.gpu_group - self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group - if self.tp > 1 and self.tp_mode == TPMode.DP_TP: - - def __gemm_func(hidden_states, topk_weights, topk_ids): - return self.gemm( - dict( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_idx=topk_ids, - moe_type=MoeType.Default, - ))['hidden_states'] - - self.forward_dptp = MoEForwardDPTP(__gemm_func) - - def update_weights(self): - """Update weights.""" - (gate_up_weights, down_weights, gate_up_scale, - down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.weight_scale_inv, - self.down.weight_scale_inv) - self.gate_up.update_weight(gate_up_weights, gate_up_scale) - self.down.update_weight(down_weights, down_scale) - - def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): - state = { - 'hidden_states': hidden_states, - 'topk_idx': topk_idx, - 'topk_weights': topk_weights, - 'moe_type': MoeType.Default, - } - recv_state = self.dispatch(state) - gemm_state = self.gemm(recv_state) - out_state = self.combine(gemm_state) - return out_state['hidden_states'] - - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): - - if self.tp > 1 and self.tp_mode == TPMode.DP_TP: - return self.forward_dptp.forward(hidden_states, topk_weights, topk_idx) - else: - return self.forward_default(hidden_states, topk_weights, topk_idx) - - def before_dispatch(self, state: Dict): - moe_type = state['moe_type'] - if moe_type == MoeType.DSAsyncPrefill: - fusedmoe = self.fusedmoe_build(low_latency_mode=False) - state['fusedmoe'] = fusedmoe - if hasattr(fusedmoe, 'per_token_group_quant_fp8'): - state['hidden_states'] = fusedmoe.per_token_group_quant_fp8(state['hidden_states']) - previous_event = fusedmoe.capture() - state['previous_event'] = previous_event - return state - - def dispatch(self, state: Dict): - moe_type = state['moe_type'] - if moe_type == MoeType.DSAsyncPrefill: - fusedmoe = state['fusedmoe'] - previous_event = 
state['previous_event'] - ( - recv_hidden_states, - recv_topk_idx, - recv_topk_weights, - recv_tokens_per_expert, - handle, - event, - ) = fusedmoe.dispatch_async(state['hidden_states'], - state['topk_idx'], - state['topk_weights'], - previous_event=previous_event, - async_finish=True) - recv_state = { - 'fusedmoe': fusedmoe, - 'recv_hidden_states': recv_hidden_states, - 'recv_topk_idx': recv_topk_idx, - 'recv_topk_weights': recv_topk_weights, - 'recv_tokens_per_expert': recv_tokens_per_expert, - 'handle': handle, - 'event': event, - 'num_experts': self.num_experts, - 'moe_type': state['moe_type'] - } - elif moe_type == MoeType.DSAsyncDecode: - fusedmoe = self.fusedmoe_build(low_latency_mode=True) - use_event = False - (recv_hidden_states, recv_expert_count, handle, event, - hook) = fusedmoe.dispatch_async(state['hidden_states'], - state['topk_idx'], - use_fp8=True, - async_finish=use_event) - recv_state = { - 'fusedmoe': fusedmoe, - 'recv_hidden_states': recv_hidden_states, - 'recv_expert_count': recv_expert_count, - 'topk_idx': state['topk_idx'], - 'topk_weights': state['topk_weights'], - 'raw_hidden_shape': state['raw_hidden_shape'], - 'handle': handle, - 'moe_type': state['moe_type'] - } - if use_event: - recv_state['event'] = event - else: - recv_state['hook'] = hook - else: # MoeType.Default - hidden_states, topk_weights, topk_idx = _moe_gather_inputs(state['hidden_states'], - state['topk_weights'], - state['topk_idx'], - group=self.gather_group) - recv_state = { - 'hidden_states': hidden_states, - 'topk_idx': topk_idx, - 'topk_weights': topk_weights, - 'moe_type': state['moe_type'] - } - return recv_state - - def gemm(self, state: Dict): - moe_type = state['moe_type'] - if moe_type == MoeType.DSAsyncPrefill: - if (state['recv_hidden_states'][0] - if isinstance(state['recv_hidden_states'], tuple) else state['recv_hidden_states']).shape[0] > 0: - state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, - self.gate_up.weight_scale_inv, - self.down.weight, - self.down.weight_scale_inv) - gemm_state = { - 'fusedmoe': state['fusedmoe'], - 'hidden_states': state['recv_hidden_states'], - 'handle': state['handle'], - 'moe_type': state['moe_type'] - } - elif moe_type == MoeType.DSAsyncDecode: - state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, - self.gate_up.weight_scale_inv, - self.down.weight, - self.down.weight_scale_inv) - gemm_state = { - 'fusedmoe': state['fusedmoe'], - 'hidden_states': state['recv_hidden_states'], - 'topk_idx': state['topk_idx'], - 'topk_weights': state['topk_weights'], - 'handle': state['handle'], - 'moe_type': state['moe_type'] - } - else: # MoeType.Default - hidden_states = self.impl.forward(state['hidden_states'], - state['topk_weights'], - state['topk_idx'], - self.gate_up.weight, - self.gate_up.weight_scale_inv, - self.down.weight, - self.down.weight_scale_inv, - gate_up_bias=self.gate_up.bias, - down_bias=self.down.bias, - expert_list=self.expert_list, - act_func=self.act_func) - gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']} - return gemm_state - - def combine(self, state: Dict): - moe_type = state['moe_type'] - if moe_type == MoeType.DSAsyncPrefill: - fusedmoe = state['fusedmoe'] - previous_event = fusedmoe.capture() - out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'], - state['handle'], - previous_event=previous_event, - async_finish=True) - out_state = { - 'fusedmoe': state['fusedmoe'], - 'hidden_states': out_hidden_states, - 
'event': event, - 'moe_type': state['moe_type'] - } - elif moe_type == MoeType.DSAsyncDecode: - fusedmoe = state['fusedmoe'] - use_event = False - out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'], - state['topk_idx'], - state['topk_weights'], - state['handle'], - async_finish=use_event) - out_state = { - 'fusedmoe': state['fusedmoe'], - 'hidden_states': out_hidden_states, - 'moe_type': state['moe_type'] - } - if use_event: - out_state['event'] = event - else: - out_state['hook'] = hook - else: # MoeType.Default - if self.all_reduce: - state['hidden_states'] = _moe_reduce(state['hidden_states'], - rank=self.tp_rank, - tp_mode=self.tp_mode, - group=self.tp_group) - out_state = {'hidden_states': state['hidden_states'], 'moe_type': state['moe_type']} - return out_state - - def wait(self, state): - if state.get('event', None) is not None: - state['fusedmoe'].wait(state['event']) - return True - elif state.get('hook', None) is not None: - state['hook']() - return True - else: - return False - - def renormalize(self, topk_weights): - return self.impl.do_renormalize(topk_weights) - - def fusedmoe_build(self, low_latency_mode: bool = False): - return self.impl.fusedmoe_build(low_latency_mode) - - -def build_fused_moe( - hidden_dim: int, - ffn_dim: int, - num_experts: int, - top_k: int, - bias: bool = False, - renormalize: bool = False, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - all_reduce: bool = True, - enable_ep: bool = False, - quant_config: Any = None, - layer_idx: int = 0, - act_func: Callable = None, -): - """Fused moe builder.""" - - if quant_config is None: - return FusedMoE( - hidden_dim=hidden_dim, - ffn_dim=ffn_dim, - num_experts=num_experts, - top_k=top_k, - bias=bias, - renormalize=renormalize, - dtype=dtype, - device=device, - all_reduce=all_reduce, - enable_ep=enable_ep, - act_func=act_func, - ) - - quant_method = quant_config['quant_method'] - if quant_method == 'smooth_quant': - assert not bias, 'Quant model does not support bias for now.' - assert act_func is None, ('Quant model does not support activation function for now.') - quant_dtype = eval('torch.' + quant_config.get('quant_dtype', 'int8')) - return FusedMoEW8A8( - hidden_dim=hidden_dim, - ffn_dim=ffn_dim, - num_experts=num_experts, - top_k=top_k, - renormalize=renormalize, - dtype=dtype, - quant_dtype=quant_dtype, - device=device, - all_reduce=all_reduce, - enable_ep=enable_ep, - ) - elif quant_method == 'fp8': - fmt = quant_config.get('fmt', 'e4m3') - if fmt == 'e4m3': - fp8_dtype = torch.float8_e4m3fn - elif fmt == 'e5m2': - fp8_dtype = torch.float8_e5m2 - else: - raise TypeError(f'Unsupported fp8 fmt: {fmt}') - return FusedMoEBlockedF8( - hidden_dim=hidden_dim, - ffn_dim=ffn_dim, - num_experts=num_experts, - top_k=top_k, - bias=bias, - renormalize=renormalize, - fp8_dtype=fp8_dtype, - dtype=dtype, - device=device, - all_reduce=all_reduce, - enable_ep=enable_ep, - layer_idx=layer_idx, - act_func=act_func, - ) - else: - raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/moe/__init__.py b/lmdeploy/pytorch/nn/moe/__init__.py new file mode 100644 index 0000000000..6f47c2087c --- /dev/null +++ b/lmdeploy/pytorch/nn/moe/__init__.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Any, Callable, Optional + +import torch + +from .base import MoeType, SoftmaxTopK # noqa: F401 + + +def build_fused_moe( + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + bias: bool = False, + renormalize: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + enable_ep: bool = False, + quant_config: Any = None, + layer_idx: int = 0, + act_func: Callable = None, +): + """Fused moe builder.""" + + if quant_config is None: + from .default import FusedMoE + return FusedMoE( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + bias=bias, + renormalize=renormalize, + dtype=dtype, + device=device, + all_reduce=all_reduce, + layer_idx=layer_idx, + act_func=act_func, + ) + + quant_method = quant_config['quant_method'] + if quant_method == 'smooth_quant': + assert not bias, 'Quant model does not support bias for now.' + assert act_func is None, ('Quant model does not support activation function for now.') + quant_dtype = eval('torch.' + quant_config.get('quant_dtype', 'int8')) + from .w8a8 import FusedMoEW8A8 + return FusedMoEW8A8( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + renormalize=renormalize, + dtype=dtype, + quant_dtype=quant_dtype, + device=device, + all_reduce=all_reduce, + ) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + from .blocked_fp8 import FusedMoEBlockedF8 + return FusedMoEBlockedF8( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + bias=bias, + renormalize=renormalize, + fp8_dtype=fp8_dtype, + dtype=dtype, + device=device, + all_reduce=all_reduce, + layer_idx=layer_idx, + act_func=act_func, + ) + else: + raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/moe/base.py b/lmdeploy/pytorch/nn/moe/base.py new file mode 100644 index 0000000000..3df3d8149a --- /dev/null +++ b/lmdeploy/pytorch/nn/moe/base.py @@ -0,0 +1,322 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
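The builder above dispatches purely on quant_config: None selects the default FusedMoE, quant_method 'smooth_quant' selects FusedMoEW8A8, and 'fp8' selects FusedMoEBlockedF8. A hedged usage sketch with illustrative dimensions follows; it assumes an initialized lmdeploy engine, since the layer pulls its distributed context and backend at construction time.

import torch

from lmdeploy.pytorch.nn.moe import build_fused_moe

# quant_method='fp8' with fmt='e4m3' routes to the blocked-fp8 implementation
moe = build_fused_moe(
    hidden_dim=4096,
    ffn_dim=11008,
    num_experts=64,
    top_k=6,
    renormalize=True,
    dtype=torch.bfloat16,
    device=torch.device('cuda'),
    quant_config={'quant_method': 'fp8', 'fmt': 'e4m3'},
)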
+from dataclasses import dataclass +from enum import Enum, auto +from typing import Callable, Dict, List, Optional + +import torch +import torch.nn as nn + +import lmdeploy.pytorch.distributed as dist +from lmdeploy.pytorch.backends import OpType, get_backend +from lmdeploy.pytorch.config import TPMode +from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager + + +class MoeType(Enum): + """Batch ecex type.""" + Default = auto() + DSAsyncDecode = auto() + DSAsyncPrefill = auto() + + +class SoftmaxTopK(nn.Module): + """Softmax topk.""" + + def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1): + super().__init__() + self.top_k = top_k + impl_builder = get_backend().get_layer_impl_builder(OpType.SoftmaxTopK) + self.impl = impl_builder.build(top_k, dim, n_groups=n_groups) + + def forward(self, x: torch.Tensor): + """forward.""" + return self.impl.forward(x) + + +def update_dims(hidden_dim: int, ffn_dim: int): + """Update dims.""" + world_size, _ = get_tp_world_rank('moe') + assert ffn_dim % world_size == 0 + ffn_dim = ffn_dim // world_size + return hidden_dim, ffn_dim + + +def split_size(size: int, world_size: int, align: int): + size = size // align + assert size >= world_size + base = size // world_size + remain = size % world_size + split_size = [base + 1] * remain + [base] * (world_size - remain) + split_size = [s * align for s in split_size] + return split_size + + +def moe_gather_inputs(hidden_states, topk_weights, topk_ids, group: Optional[dist.ProcessGroup] = None): + dist_config = get_dist_manager().current_config() + tp = dist_config.moe_tp + if tp == 1: + return hidden_states, topk_weights, topk_ids + + tp_mode = dist_config.moe_tp_mode + if tp_mode == TPMode.DEFAULT: + return hidden_states, topk_weights, topk_ids + elif tp_mode == TPMode.DP_TP: + step_ctx = get_step_ctx_manager().current_context() + dp_meta = step_ctx.dp_meta + tp_sizes = dp_meta.moe_tp_sizes + hidden_states = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=group) + topk_weights = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=group) + topk_ids = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=group) + else: + raise RuntimeError('Not supported.') + + return hidden_states, topk_weights, topk_ids + + +def moe_reduce(ret, rank: int, tp_mode: TPMode, group: Optional[dist.ProcessGroup] = None): + dist_config = get_dist_manager().current_config() + if dist_config.moe_tp == 1: + return ret + + if tp_mode == TPMode.DEFAULT: + dist.all_reduce(ret, group=group) + return ret + elif tp_mode == TPMode.DP_TP: + step_ctx = get_step_ctx_manager().current_context() + dp_meta = step_ctx.dp_meta + tp_size = dp_meta.moe_tp_sizes + ret = dist.reduce_scatter_by_tp_sizes(ret, rank, tp_size, group=group) + return ret + else: + raise RuntimeError('Not supported.') + + +class MoEForwardDPTP: + + def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192): + """MoE forward dp tp.""" + self.gemm_func = gemm_func + self.dist_ctx = get_dist_manager().current_context() + self.dist_config = self.dist_ctx.dist_config + self.tp = self.dist_config.moe_tp + self.attn_tp = self.dist_config.attn_tp + + tp_group = self.dist_ctx.moe_tp_group + self.rank = tp_group.rank + self.gather_rank = self.rank // self.attn_tp + self.gather_group = tp_group.gpu_gather_group + self.tp_group = tp_group.gpu_group + self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2 + + def all_gather(self, hidden_states: 
torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + tp_sizes: List[int]): + """All gather.""" + hidden_states, h0 = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True) + topk_weights, h1 = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True) + topk_ids, h2 = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True) + return hidden_states, topk_weights, topk_ids, (h0, h1, h2) + + def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]): + """Reduce scatter.""" + hidden_states_list = list(hidden_states.split(tp_sizes, -2)) + cur_out_states = hidden_states_list[self.gather_rank] + out_states.copy_(cur_out_states) + hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)] + hidden_states_list[self.rank] = out_states + handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True) + return out_states, handle + + def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + output_states: torch.Tensor, tp_sizes: List[int], handles: List[dist.Work]): + """Gemm and reduce scatter.""" + for handle in handles: + handle.wait() + cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids) + return self.reduce_scatter(cur_out, output_states, tp_sizes) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor): + """forward.""" + + def __slice_tensor(tensor: torch.Tensor, slice_size: int): + """Slice tensor.""" + cur_tensor = tensor[:slice_size] + tensor = tensor[slice_size:] + return cur_tensor, tensor + + def __slice_and_gather(): + """Slice and gather.""" + nonlocal hidden_states, topk_weights, topk_ids, tp_sizes, output_states + cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round) + tp_sizes -= cur_tp_sizes + cur_tp_sizes = cur_tp_sizes.tolist() + + slice_size = cur_tp_sizes[self.gather_rank] + cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size) + cur_topk_weights, topk_weights = __slice_tensor(topk_weights, slice_size) + cur_topk_ids, topk_ids = __slice_tensor(topk_ids, slice_size) + cur_output, output_states = __slice_tensor(output_states, slice_size) + + # all gather + cur_hidden_states, cur_topk_weights, cur_topk_ids, handles = self.all_gather( + cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes) + return dict(hidden_states=cur_hidden_states, + topk_weights=cur_topk_weights, + topk_ids=cur_topk_ids, + output_states=cur_output, + handles=handles, + tp_sizes=cur_tp_sizes) + + step_ctx = get_step_ctx_manager().current_context() + tp_sizes = step_ctx.dp_meta.moe_tp_sizes + tp_sizes = torch.tensor(tp_sizes) + max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round) + + output_states = torch.empty_like(hidden_states) + return_states = output_states + + # pre + cur_inputs = __slice_and_gather() + + out_handles = [] + # main loop + while tp_sizes.sum() > 0: + next_inputs = __slice_and_gather() + _, handle = self._gemm_and_reduce_scatter(**cur_inputs) + out_handles.append(handle) + cur_inputs = next_inputs + + # post + _, handle = self._gemm_and_reduce_scatter(**cur_inputs) + out_handles.append(handle) + for handle in out_handles: + handle.wait() + return return_states + + +def _renormalize(topk_weights: torch.Tensor, renormalize: bool): + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if 
not topk_weights.is_contiguous():
+        topk_weights = topk_weights.contiguous()
+    return topk_weights
+
+
+@dataclass
+class DispatchInputs:
+    """Dispatch inputs."""
+    hidden_states: torch.Tensor
+    topk_weights: torch.Tensor
+    topk_idx: torch.LongTensor
+    moe_type: MoeType = MoeType.Default
+
+    @classmethod
+    def from_dict(cls, input: Dict):
+        """From dict."""
+        assert all(key in input for key in ('hidden_states', 'topk_weights', 'topk_idx'))
+        moe_type = input.get('moe_type', MoeType.Default)
+        return cls(
+            hidden_states=input['hidden_states'],
+            topk_weights=input['topk_weights'],
+            topk_idx=input['topk_idx'],
+            moe_type=moe_type,
+        )
+
+    def to_dict(self) -> Dict:
+        """To dict."""
+        return {
+            'hidden_states': self.hidden_states,
+            'topk_weights': self.topk_weights,
+            'topk_idx': self.topk_idx,
+            'moe_type': self.moe_type,
+        }
+
+
+class FusedMoEBase(nn.Module):
+    """Fused MoE base."""
+
+    def __init__(self, tp: int, tp_mode: TPMode, do_renormalize: bool):
+        super().__init__()
+        self.tp = tp
+        self.tp_mode = tp_mode
+        self.do_renormalize = do_renormalize
+
+    def init_dist_args(self, all_reduce: bool):
+        """Init tp args."""
+        dist_ctx = get_dist_manager().current_context()
+        dist_cfg = dist_ctx.dist_config
+        _, tp_mode = dist_cfg.get_tp_by_layer('moe')
+        tp, tp_rank = get_tp_world_rank('moe')
+        all_reduce = all_reduce if tp > 1 else False
+
+        self.ep = dist_cfg.ep
+        self.tp = tp
+        self.tp_rank = tp_rank
+        self.tp_mode = tp_mode
+        self.all_reduce = all_reduce
+        self.tp_group = dist_ctx.moe_tp_group.gpu_group
+        self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group
+
+        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:
+
+            def __gemm_func(hidden_states, topk_weights, topk_ids):
+                return self.gemm(
+                    dict(
+                        hidden_states=hidden_states,
+                        topk_weights=topk_weights,
+                        topk_idx=topk_ids,
+                        moe_type=MoeType.Default,
+                    ))['hidden_states']
+
+            self._forward_dptp = MoEForwardDPTP(__gemm_func)
+        else:
+            self._forward_dptp = None
+
+    def before_dispatch(self, state: DispatchInputs):
+        """Before dispatch."""
+        raise NotImplementedError
+
+    def dispatch(self, state: Dict):
+        """dispatch."""
+        raise NotImplementedError
+
+    def gemm(self, state: Dict):
+        """gemm."""
+        raise NotImplementedError
+
+    def combine(self, state: Dict):
+        """combine."""
+        raise NotImplementedError
+
+    def wait(self, state: Dict):
+        """wait."""
+        raise NotImplementedError
+
+    @property
+    def forward_dptp(self) -> MoEForwardDPTP:
+        """Forward dptp."""
+        raise NotImplementedError
+
+    def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor):
+        """Default forward."""
+        state = {
+            'hidden_states': hidden_states,
+            'topk_idx': topk_idx,
+            'topk_weights': topk_weights,
+            'moe_type': MoeType.Default,
+        }
+        recv_state = self.dispatch(state)
+        gemm_state = self.gemm(recv_state)
+        out_state = self.combine(gemm_state)
+        return out_state['hidden_states']
+
+    def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor):
+        """forward."""
+        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:
+            return self.forward_dptp.forward(hidden_states, topk_weights, topk_idx)
+        else:
+            return self.forward_default(hidden_states, topk_weights, topk_idx)
+
+    def renormalize(self, topk_weights):
+        """renormalize."""
+        return _renormalize(topk_weights, self.do_renormalize)
diff --git a/lmdeploy/pytorch/nn/moe/blocked_fp8.py b/lmdeploy/pytorch/nn/moe/blocked_fp8.py
new file mode 100644
index 0000000000..84d6d43a54
--- /dev/null
+++ b/lmdeploy/pytorch/nn/moe/blocked_fp8.py
@@ -0,0 +1,408
@@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Dict, List, Optional + +import torch + +from lmdeploy.pytorch.backends import OpType, get_backend +from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank + +from ..quant_utils import quant_blocked_fp8 +from ..utils import div_up +from .base import DispatchInputs, FusedMoEBase, MoEForwardDPTP, MoeType, moe_gather_inputs, moe_reduce +from .base import split_size as _split_size +from .default import LinearWeights + + +class LinearWeightsBlockedF8(LinearWeights): + """Fused moe linear blocked fp8 weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + block_size: int, + dtype: torch.dtype, + device: torch.device, + bias: bool = False, + expert_list: List[int] = None): + super().__init__(num_experts=num_experts, + in_features=in_features, + out_features=out_features, + weight_type=weight_type, + dtype=dtype, + device=device, + bias=bias, + expert_list=expert_list) + self.block_size = block_size + weight_scale_inv = torch.empty((num_experts, div_up(out_features, block_size), div_up(in_features, block_size)), + dtype=torch.float32, + device=device) + weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) + self.register_parameter('weight_scale_inv', weight_scale_inv) + + if self.ep: + self.weight._base_weight_loader = self.weight.weight_loader + self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep + else: + self.weight._base_weight_loader = self.weight_loader_tp_blocked_fp8 + self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp + self.weight.weight_loader = self.weight_loader_with_quant + + def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor): + """Update weight.""" + super().update_weight(weight=weight) + weight_loader = self.weight_scale_inv.weight_loader + weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) + weight_scale_inv.weight_loader = weight_loader + self.register_parameter('weight_scale_inv', weight_scale_inv) + + def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + expert_list = self.expert_list + if expert_id not in expert_list: + return + expert_ids = self.expert_map[expert_id] + for expert_id in expert_ids: + self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) + + def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank: int, align: int): + """Chunk with align.""" + split_size = _split_size(weight.size(dim), world_size, align) + return weight.split(split_size, dim=dim)[rank] + + def weight_loader_tp_blocked_fp8(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """Weight loader.""" + world_size, rank = get_tp_world_rank('moe') + if shard_id == 'gate': + param_data = param.data[expert_id, :self.half_out] + weight = self._chunk_weight_tp(loaded_weight, + dim=0, + world_size=world_size, + rank=rank, + align=self.block_size) + elif shard_id == 'up': + param_data = param.data[expert_id, self.half_out:] + weight = self._chunk_weight_tp(loaded_weight, + dim=0, + world_size=world_size, + rank=rank, + align=self.block_size) + elif shard_id == 'down': + param_data = param.data[expert_id] + # weight is not contiguous, chunk and copy in cpu is slow + weight = loaded_weight.to(param_data.device) + if weight.dim() > 1: + weight = self._chunk_weight_tp(weight, 
dim=1, world_size=world_size, rank=rank, align=self.block_size) + elif weight.dim() == 1 and rank != 0: + # bias with rank>0 should be 0 + weight = torch.zeros_like(weight) + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + + def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """Weight loader scale tp.""" + world_size, rank = get_tp_world_rank('moe') + block_size = self.block_size + half_out = self.half_out // block_size + if shard_id == 'gate': + param_data = param.data[expert_id, :half_out] + weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) + elif shard_id == 'up': + param_data = param.data[expert_id, half_out:] + weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) + elif shard_id == 'down': + param_data = param.data[expert_id] + loaded_weight = loaded_weight.to(param_data.device) + weight = self._chunk_weight_tp(loaded_weight, dim=1, world_size=world_size, rank=rank, align=1) + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + + def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """Weight load with quant.""" + if loaded_weight.dtype != param.dtype: + # quant loaded weight + quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device), param.dtype, self.block_size) + self.weight._base_weight_loader(self.weight, quanted_weight, expert_id, shard_id) + self.weight_scale_inv.weight_loader(self.weight_scale_inv, scaling, expert_id, shard_id) + else: + return self.weight._base_weight_loader(param, loaded_weight, expert_id, shard_id) + + +class FusedMoEBlockedF8(FusedMoEBase): + """Fused moe blocked f8.""" + + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + bias: bool = False, + renormalize: bool = False, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + layer_idx: int = 0, + act_func: Callable = None): + + device = device or torch.device('cpu') + dtype = dtype or torch.float16 + # init distributed tp arguments + self.block_size = 128 + self.init_dist_args(all_reduce) + + super().__init__( + tp=self.tp, + tp_mode=self.tp_mode, + do_renormalize=renormalize, + ) + + dist_ctx = get_dist_manager().current_context() + self.ep_size, rank = get_ep_world_rank() + impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8) + self.impl = impl_builder.build(top_k, + num_experts, + hidden_dim, + renormalize, + block_size=self.block_size, + ep_size=self.ep_size, + ep_group=dist_ctx.ep_gpu_group, + out_dtype=dtype, + layer_idx=layer_idx, + custom_gateup_act=act_func is not None) + + if self.ep_size > 1: + expert_list = self.impl.ep_expert_list(self.ep_size, rank) + num_experts = len(expert_list) + else: + hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim, align=self.block_size) + expert_list = None + self.expert_list = expert_list + + # create weights + self.gate_up = LinearWeightsBlockedF8(num_experts, + hidden_dim, + ffn_dim * 2, + weight_type='gate_up', + block_size=self.block_size, + dtype=fp8_dtype, + device=device, + bias=bias, + expert_list=expert_list) + self.down = LinearWeightsBlockedF8(num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + block_size=self.block_size, + dtype=fp8_dtype, + 
device=device, + bias=bias, + expert_list=expert_list) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + self.act_func = act_func + + @staticmethod + def _update_args(hidden_dim: int, ffn_dim: int, align: int): + world_size, rank = get_tp_world_rank('moe') + split_size = _split_size(ffn_dim, world_size, align) + ffn_dim = split_size[rank] + return hidden_dim, ffn_dim + + def update_weights(self): + """Update weights.""" + (gate_up_weights, down_weights, gate_up_scale, + down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.weight_scale_inv, + self.down.weight_scale_inv) + self.gate_up.update_weight(gate_up_weights, gate_up_scale) + self.down.update_weight(down_weights, down_scale) + + def before_dispatch(self, state: DispatchInputs): + """Before dispatch.""" + if not isinstance(state, Dict): + state = state.to_dict() + + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + fusedmoe = self.fusedmoe_build(low_latency_mode=False) + state['fusedmoe'] = fusedmoe + if hasattr(fusedmoe, 'per_token_group_quant_fp8'): + state['hidden_states'] = fusedmoe.per_token_group_quant_fp8(state['hidden_states']) + previous_event = fusedmoe.capture() + state['previous_event'] = previous_event + return state + + def dispatch(self, state: Dict): + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + fusedmoe = state['fusedmoe'] + previous_event = state['previous_event'] + ( + recv_hidden_states, + recv_topk_idx, + recv_topk_weights, + recv_tokens_per_expert, + handle, + event, + ) = fusedmoe.dispatch_async(state['hidden_states'], + state['topk_idx'], + state['topk_weights'], + previous_event=previous_event, + async_finish=True) + recv_state = { + 'fusedmoe': fusedmoe, + 'recv_hidden_states': recv_hidden_states, + 'recv_topk_idx': recv_topk_idx, + 'recv_topk_weights': recv_topk_weights, + 'recv_tokens_per_expert': recv_tokens_per_expert, + 'handle': handle, + 'event': event, + 'num_experts': self.num_experts, + 'moe_type': state['moe_type'] + } + elif moe_type == MoeType.DSAsyncDecode: + fusedmoe = self.fusedmoe_build(low_latency_mode=True) + use_event = False + (recv_hidden_states, recv_expert_count, handle, event, + hook) = fusedmoe.dispatch_async(state['hidden_states'], + state['topk_idx'], + use_fp8=True, + async_finish=use_event) + recv_state = { + 'fusedmoe': fusedmoe, + 'recv_hidden_states': recv_hidden_states, + 'recv_expert_count': recv_expert_count, + 'topk_idx': state['topk_idx'], + 'topk_weights': state['topk_weights'], + 'raw_hidden_shape': state['raw_hidden_shape'], + 'handle': handle, + 'moe_type': state['moe_type'] + } + if use_event: + recv_state['event'] = event + else: + recv_state['hook'] = hook + else: # MoeType.Default + hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'], + state['topk_weights'], + state['topk_idx'], + group=self.gather_group) + recv_state = { + 'hidden_states': hidden_states, + 'topk_idx': topk_idx, + 'topk_weights': topk_weights, + 'moe_type': state['moe_type'] + } + return recv_state + + def gemm(self, state: Dict): + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + if (state['recv_hidden_states'][0] + if isinstance(state['recv_hidden_states'], tuple) else state['recv_hidden_states']).shape[0] > 0: + state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, + self.gate_up.weight_scale_inv, + self.down.weight, + 
self.down.weight_scale_inv) + gemm_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': state['recv_hidden_states'], + 'handle': state['handle'], + 'moe_type': state['moe_type'] + } + elif moe_type == MoeType.DSAsyncDecode: + state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, + self.gate_up.weight_scale_inv, + self.down.weight, + self.down.weight_scale_inv) + gemm_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': state['recv_hidden_states'], + 'topk_idx': state['topk_idx'], + 'topk_weights': state['topk_weights'], + 'handle': state['handle'], + 'moe_type': state['moe_type'] + } + else: # MoeType.Default + hidden_states = self.impl.forward(state['hidden_states'], + state['topk_weights'], + state['topk_idx'], + self.gate_up.weight, + self.gate_up.weight_scale_inv, + self.down.weight, + self.down.weight_scale_inv, + gate_up_bias=self.gate_up.bias, + down_bias=self.down.bias, + expert_list=self.expert_list, + act_func=self.act_func) + gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']} + return gemm_state + + def combine(self, state: Dict): + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + fusedmoe = state['fusedmoe'] + previous_event = fusedmoe.capture() + out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'], + state['handle'], + previous_event=previous_event, + async_finish=True) + out_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': out_hidden_states, + 'event': event, + 'moe_type': state['moe_type'] + } + elif moe_type == MoeType.DSAsyncDecode: + fusedmoe = state['fusedmoe'] + use_event = False + out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'], + state['topk_idx'], + state['topk_weights'], + state['handle'], + async_finish=use_event) + out_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': out_hidden_states, + 'moe_type': state['moe_type'] + } + if use_event: + out_state['event'] = event + else: + out_state['hook'] = hook + else: # MoeType.Default + if self.all_reduce: + state['hidden_states'] = moe_reduce(state['hidden_states'], + rank=self.tp_rank, + tp_mode=self.tp_mode, + group=self.tp_group) + out_state = {'hidden_states': state['hidden_states'], 'moe_type': state['moe_type']} + return out_state + + def wait(self, state): + if state.get('event', None) is not None: + state['fusedmoe'].wait(state['event']) + return True + elif state.get('hook', None) is not None: + state['hook']() + return True + else: + return False + + @property + def forward_dptp(self) -> MoEForwardDPTP: + """Forward dptp.""" + return self._forward_dptp + + def fusedmoe_build(self, low_latency_mode: bool = False): + return self.impl.fusedmoe_build(low_latency_mode) diff --git a/lmdeploy/pytorch/nn/moe/default.py b/lmdeploy/pytorch/nn/moe/default.py new file mode 100644 index 0000000000..90b41fe0b2 --- /dev/null +++ b/lmdeploy/pytorch/nn/moe/default.py @@ -0,0 +1,375 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from collections import defaultdict +from typing import Callable, Dict, List, Optional + +import torch +from torch import nn + +from lmdeploy.pytorch.backends import OpType, get_backend +from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank + +from .base import DispatchInputs, FusedMoEBase, MoEForwardDPTP, MoeType, moe_gather_inputs, moe_reduce, update_dims + + +class LinearWeights(nn.Module): + """Fused moe linear weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + dtype: torch.dtype, + device: torch.device, + bias: bool = False, + expert_list: Optional[List[int]] = None): + super().__init__() + weight = torch.empty((num_experts, out_features, in_features), dtype=dtype, device=device) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.register_parameter('weight', weight) + + if bias: + bias = torch.empty((num_experts, out_features), dtype=dtype, device=device) + bias = torch.nn.Parameter(bias, requires_grad=False) + self.register_parameter('bias', bias) + else: + self.bias = None + + self.ep = expert_list is not None + self.expert_list = expert_list + self.weight_type = weight_type + self.half_out = out_features // 2 + + self.setup_weight_loader() + + def setup_weight_loader(self): + """Setup weight loader.""" + if self.expert_list is not None: + self.expert_map = defaultdict(list) + for idx, eid in enumerate(self.expert_list): + self.expert_map[eid].append(idx) + self.weight.weight_loader = self.weight_loader_ep + if self.bias is not None: + self.bias.weight_loader = self.weight_loader_ep + else: + self.weight.weight_loader = self.weight_loader_tp + if self.bias is not None: + self.bias.weight_loader = self.weight_loader_tp + + def update_weight(self, weight: torch.Tensor): + """Update weight.""" + weight_loader = self.weight.weight_loader + weight = torch.nn.Parameter(weight, requires_grad=False) + weight.weight_loader = weight_loader + self.register_parameter('weight', weight) + + def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): + """Weight loader.""" + world_size, rank = get_tp_world_rank('moe') + if shard_id == 'gate': + param_data = param.data[expert_id, :self.half_out] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'up': + param_data = param.data[expert_id, self.half_out:] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'down': + param_data = param.data[expert_id] + # weight is not contiguous, chunk and copy in cpu is slow + weight = loaded_weight.to(param_data.device) + if weight.dim() > 1: + weight = weight.chunk(world_size, dim=1)[rank] + elif weight.dim() == 1 and rank != 0: + # bias with rank>0 should be 0 + weight = torch.zeros_like(weight) + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + + def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): + """Weight loader.""" + expert_list = self.expert_list + if expert_id not in expert_list: + return + + expert_map = self.expert_map + param_ids = expert_map[expert_id] + for param_id in param_ids: + if shard_id == 'gate': + param_data = param.data[param_id, :self.half_out] + elif shard_id == 'up': + param_data = param.data[param_id, self.half_out:] + elif shard_id == 'down': + param_data = param.data[param_id] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + 
param_data.copy_(loaded_weight) + + +class FusedMoE(FusedMoEBase): + """Fused MoE.""" + + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + bias: bool = False, + renormalize: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + layer_idx: int = 0, + act_func: Callable = None): + + device = device or torch.device('cpu') + dtype = dtype or torch.float16 + # init distributed tp arguments + self.init_dist_args(all_reduce) + + super().__init__( + tp=self.tp, + tp_mode=self.tp_mode, + do_renormalize=renormalize, + ) + + # create implementation + dist_ctx = get_dist_manager().current_context() + self.ep_size, rank = get_ep_world_rank() + impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) + self.impl = impl_builder.build( + top_k, + num_experts, + renormalize, + hidden_dim=hidden_dim, + ep_size=self.ep_size, + ep_group=dist_ctx.ep_gpu_group, + layer_idx=layer_idx, + ) + + # create weights + if self.ep_size > 1: + expert_list = self.impl.ep_expert_list(self.ep_size, rank) + num_experts = len(expert_list) + else: + hidden_dim, ffn_dim = update_dims(hidden_dim, ffn_dim) + expert_list = None + self.expert_list = expert_list + self.gate_up = LinearWeights(num_experts, + hidden_dim, + ffn_dim * 2, + weight_type='gate_up', + dtype=dtype, + device=device, + bias=bias, + expert_list=expert_list) + self.down = LinearWeights( + num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + dtype=dtype, + device=device, + bias=bias, + expert_list=expert_list, + ) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + self.act_func = act_func + + def update_weights(self): + """Update weights.""" + gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight) + self.gate_up.update_weight(gate_up_weights) + self.down.update_weight(down_weights) + + def before_dispatch(self, state: DispatchInputs): + """Before dispatch.""" + if not isinstance(state, Dict): + state = state.to_dict() + + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + fusedmoe = self.fusedmoe_build(low_latency_mode=False) + state['fusedmoe'] = fusedmoe + previous_event = fusedmoe.capture() + state['previous_event'] = previous_event + return state + + def dispatch(self, state: Dict): + """dispatch.""" + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + fusedmoe = state['fusedmoe'] + previous_event = state['previous_event'] + ( + recv_hidden_states, + recv_topk_idx, + recv_topk_weights, + recv_tokens_per_expert, + handle, + event, + ) = fusedmoe.dispatch_async(state['hidden_states'], + state['topk_idx'], + state['topk_weights'], + previous_event=previous_event, + async_finish=True) + recv_state = { + 'fusedmoe': fusedmoe, + 'recv_hidden_states': recv_hidden_states, + 'recv_topk_idx': recv_topk_idx, + 'recv_topk_weights': recv_topk_weights, + 'recv_tokens_per_expert': recv_tokens_per_expert, + 'handle': handle, + 'event': event, + 'num_experts': self.num_experts, + 'moe_type': state['moe_type'] + } + elif moe_type == MoeType.DSAsyncDecode: + fusedmoe = self.fusedmoe_build(low_latency_mode=True) + use_event = False + (recv_hidden_states, recv_expert_count, handle, event, + hook) = fusedmoe.dispatch_async(state['hidden_states'], + state['topk_idx'], + use_fp8=False, + async_finish=use_event) + recv_state = { + 'fusedmoe': fusedmoe, + 'recv_hidden_states': 
recv_hidden_states, + 'recv_expert_count': recv_expert_count, + 'topk_idx': state['topk_idx'], + 'topk_weights': state['topk_weights'], + 'raw_hidden_shape': state['raw_hidden_shape'], + 'handle': handle, + 'moe_type': state['moe_type'] + } + if use_event: + recv_state['event'] = event + else: + recv_state['hook'] = hook + elif moe_type == MoeType.Default: + hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'], + state['topk_weights'], + state['topk_idx'], + group=self.gather_group) + recv_state = { + 'hidden_states': hidden_states, + 'topk_idx': topk_idx, + 'topk_weights': topk_weights, + 'moe_type': moe_type + } + else: + raise NotImplementedError(f'Not supported moe type: {moe_type}') + return recv_state + + def gemm(self, state: Dict): + """gemm.""" + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + if (state['recv_hidden_states'][0] + if isinstance(state['recv_hidden_states'], tuple) else state['recv_hidden_states']).shape[0] > 0: + state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, + self.gate_up.weight_scale_inv, + self.down.weight, + self.down.weight_scale_inv) + gemm_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': state['recv_hidden_states'], + 'handle': state['handle'], + 'moe_type': state['moe_type'] + } + elif moe_type == MoeType.DSAsyncDecode: + state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, + self.gate_up.weight_scale_inv, + self.down.weight, + self.down.weight_scale_inv) + gemm_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': state['recv_hidden_states'], + 'topk_idx': state['topk_idx'], + 'topk_weights': state['topk_weights'], + 'handle': state['handle'], + 'moe_type': state['moe_type'] + } + else: + hidden_states = state['hidden_states'] + topk_weights = state['topk_weights'] + topk_ids = state['topk_idx'] + + hidden_states = self.impl.forward(hidden_states, + topk_weights, + topk_ids, + self.gate_up.weight, + self.down.weight, + self.gate_up.bias, + self.down.bias, + self.expert_list, + act_func=self.act_func) + gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']} + return gemm_state + + def combine(self, state: Dict): + """combine.""" + moe_type = state['moe_type'] + if moe_type == MoeType.DSAsyncPrefill: + fusedmoe = state['fusedmoe'] + previous_event = fusedmoe.capture() + out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'], + state['handle'], + previous_event=previous_event, + async_finish=True) + out_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': out_hidden_states, + 'event': event, + 'moe_type': state['moe_type'] + } + elif moe_type == MoeType.DSAsyncDecode: + fusedmoe = state['fusedmoe'] + use_event = False + out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'], + state['topk_idx'], + state['topk_weights'], + state['handle'], + async_finish=use_event) + out_state = { + 'fusedmoe': state['fusedmoe'], + 'hidden_states': out_hidden_states, + 'moe_type': state['moe_type'] + } + if use_event: + out_state['event'] = event + else: + out_state['hook'] = hook + elif moe_type == MoeType.Default: + if self.all_reduce: + state['hidden_states'] = moe_reduce(state['hidden_states'], + rank=self.tp_rank, + tp_mode=self.tp_mode, + group=self.tp_group) + out_state = {'hidden_states': state['hidden_states'], 'moe_type': moe_type} + else: + raise NotImplementedError(f'Not supported moe type: {moe_type}') + return out_state + + def 
wait(self, state: Dict): + """wait.""" + if state.get('event', None) is not None: + state['fusedmoe'].wait(state['event']) + return True + elif state.get('hook', None) is not None: + state['hook']() + return True + else: + return False + + @property + def forward_dptp(self) -> MoEForwardDPTP: + """Forward dptp.""" + return self._forward_dptp + + def fusedmoe_build(self, low_latency_mode: bool = False): + return self.impl.fusedmoe_build(low_latency_mode) diff --git a/lmdeploy/pytorch/nn/moe/w8a8.py b/lmdeploy/pytorch/nn/moe/w8a8.py new file mode 100644 index 0000000000..c92e03be46 --- /dev/null +++ b/lmdeploy/pytorch/nn/moe/w8a8.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import torch + +from lmdeploy.pytorch.backends import OpType, get_backend +from lmdeploy.pytorch.distributed import get_tp_world_rank + +from .base import FusedMoEBase, MoEForwardDPTP, MoeType, moe_gather_inputs, moe_reduce, update_dims +from .default import LinearWeights + + +class LinearWeightsW8A8(LinearWeights): + """Fused moe linear w8a8 weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + device: torch.device, + expert_list: List[int] = None, + quant_dtype: torch.dtype = torch.int8): + super().__init__( + num_experts=num_experts, + in_features=in_features, + out_features=out_features, + weight_type=weight_type, + dtype=quant_dtype, + device=device, + expert_list=expert_list, + ) + scale = torch.empty((num_experts, out_features, 1), dtype=torch.float32, device=device) + scale = torch.nn.Parameter(scale, requires_grad=False) + self.register_parameter('scale', scale) + + if self.ep: + self.scale.weight_loader = self.weight_loader_ep + else: + self.scale.weight_loader = self.weight_loader_scale_tp + + def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): + """Update weight.""" + super().update_weight(weight=weight) + weight_loader = self.scale.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = weight_loader + self.register_parameter('scale', scale) + + def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """Weight loader scale tp.""" + world_size, rank = get_tp_world_rank('moe') + if shard_id == 'gate': + param_data = param.data[expert_id, :self.half_out] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'up': + param_data = param.data[expert_id, self.half_out:] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'down': + param_data = param.data[expert_id] + weight = loaded_weight + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + weight = weight.to(param.dtype) + param_data.copy_(weight) + + +class FusedMoEW8A8(FusedMoEBase): + """Fused moe w8a8.""" + + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + renormalize: bool = False, + dtype: Optional[torch.dtype] = None, + quant_dtype: Optional[torch.dtype] = torch.int8, + device: Optional[torch.device] = None, + all_reduce: bool = True): + + device = device or torch.device('cpu') + dtype = dtype or torch.float16 + # init distributed tp arguments + self.init_dist_args(all_reduce) + + # check ep + if self.ep > 1: + raise RuntimeError('FusedMoEW8A8 does not support EP mode now.') + + super().__init__( + tp=self.tp, + tp_mode=self.tp_mode, + do_renormalize=renormalize, + ) + + # create implementation + impl_builder = 
get_backend().get_layer_impl_builder(OpType.FusedMoEW8A8) + self.impl = impl_builder.build(top_k, num_experts, renormalize, dtype, quant_dtype=quant_dtype) + + # create weights + hidden_dim, ffn_dim = update_dims(hidden_dim, ffn_dim) + expert_list = None + self.expert_list = expert_list + self.gate_up = LinearWeightsW8A8(num_experts, + hidden_dim, + ffn_dim * 2, + weight_type='gate_up', + device=device, + expert_list=expert_list, + quant_dtype=quant_dtype) + self.down = LinearWeightsW8A8(num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + device=device, + expert_list=expert_list, + quant_dtype=quant_dtype) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + self.all_reduce = all_reduce + + def update_weights(self): + """Update weights.""" + (gate_up_weights, down_weights, gate_up_scale, + down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.scale, + self.down.scale) + self.gate_up.update_weight(gate_up_weights, gate_up_scale) + self.down.update_weight(down_weights, down_scale) + + def dispatch(self, state: Dict): + """dispatch.""" + moe_type = state['moe_type'] + if moe_type == MoeType.Default: + hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'], + state['topk_weights'], + state['topk_idx'], + group=self.gather_group) + recv_state = { + 'hidden_states': hidden_states, + 'topk_idx': topk_idx, + 'topk_weights': topk_weights, + 'moe_type': moe_type + } + else: + raise NotImplementedError(f'Not supported moe type: {moe_type}') + return recv_state + + def gemm(self, state: Dict): + """gemm.""" + hidden_states = state['hidden_states'] + topk_weights = state['topk_weights'] + topk_ids = state['topk_idx'] + + ret = self.impl.forward(hidden_states, topk_weights, topk_ids, self.gate_up.weight, self.gate_up.scale, + self.down.weight, self.down.scale, self.expert_list) + return dict(hidden_states=ret, moe_type=state['moe_type']) + + def combine(self, state: Dict): + """combine.""" + moe_type = state['moe_type'] + if moe_type == MoeType.Default: + if self.all_reduce: + state['hidden_states'] = moe_reduce(state['hidden_states'], + rank=self.tp_rank, + tp_mode=self.tp_mode, + group=self.tp_group) + out_state = {'hidden_states': state['hidden_states'], 'moe_type': moe_type} + else: + raise NotImplementedError(f'Not supported moe type: {moe_type}') + return out_state + + def wait(self, state: Dict): + """wait.""" + raise NotImplementedError + + @property + def forward_dptp(self) -> MoEForwardDPTP: + """Forward dptp.""" + return self._forward_dptp diff --git a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py index 369862e60e..2e3929f906 100644 --- a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py +++ b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py @@ -82,3 +82,10 @@ def m_grouped_fp8_gemm_nt_masked(a, def get_mn_major_tma_aligned_tensor(x): return get_col_major_tma_aligned_tensor(x) + + +try: + from deep_gemm import m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_masked # noqa: F401 +except Exception: + logger.warning('DeepGemm bf16 grouped gemm kernels are not found. ' + 'Please upgrade DeepGemm to the latest version.')
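
Usage sketch (not part of the diff): the snippet below shows roughly how the refactored pieces are meant to be wired together — a dense router producing logits, SoftmaxTopK selecting experts, and FusedMoE running its dispatch/gemm/combine pipeline. It assumes the module paths lmdeploy.pytorch.nn.moe.base and lmdeploy.pytorch.nn.moe.default introduced above, that the engine has already initialized the distributed and step contexts before these layers are constructed, and that SoftmaxTopK.forward returns a (topk_weights, topk_ids) pair; the SparseMoEBlock wrapper, its dimensions, and the router linear are illustrative only.

# Sketch only -- illustrates the intended wiring; not part of this change.
import torch
from torch import nn

from lmdeploy.pytorch.nn.moe.base import SoftmaxTopK  # assumed module path for the new base.py
from lmdeploy.pytorch.nn.moe.default import FusedMoE


class SparseMoEBlock(nn.Module):
    """Illustrative MoE block: router -> top-k gate -> fused experts."""

    def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int, top_k: int):
        super().__init__()
        # Dense router producing per-expert logits for every token.
        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
        # Backend-selected top-k softmax (assumed to return (weights, ids)).
        self.gate = SoftmaxTopK(top_k)
        # Unquantized fused experts; the dist/step contexts must already be
        # set up by the engine when this is constructed.
        self.experts = FusedMoE(hidden_dim,
                                ffn_dim,
                                num_experts,
                                top_k,
                                renormalize=True,
                                dtype=torch.bfloat16,
                                device=torch.device('cuda'))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_dim]
        router_logits = self.router(hidden_states)
        topk_weights, topk_ids = self.gate(router_logits)
        # FusedMoE.forward picks either the default path or the chunked DP/TP
        # path (MoEForwardDPTP) and internally runs dispatch -> gemm -> combine.
        return self.experts(hidden_states, topk_weights, topk_ids)

The quantized variants (FusedMoEW8A8, FusedMoEBlockedF8) would drop into the same block unchanged, since they share the FusedMoEBase forward(hidden_states, topk_weights, topk_idx) contract.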