From 8b95cfe32be57e22fc46af4cccc22c2e624e7916 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Tue, 1 Jul 2025 17:22:36 +0800
Subject: [PATCH 1/8] [quant] deepgemm-fp8w8a8-b128 quantize function

---
 .../common/quantization/deepgemm_quant.py     | 10 +++--
 .../fp8/fp8w8a8_block_quant_kernel.py         | 44 +++++++++++++++++++
 .../layer_infer/transformer_layer_infer.py    |  8 ++--
 3 files changed, 54 insertions(+), 8 deletions(-)
 create mode 100644 lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py

diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py
index 622a9711c..2d8f828be 100644
--- a/lightllm/common/quantization/deepgemm_quant.py
+++ b/lightllm/common/quantization/deepgemm_quant.py
@@ -41,11 +41,15 @@ def __init__(self):
         self.act_scale_suffix = None  # no support for static input tensor scale for ds model.
 
     def quantize(self, weight: torch.Tensor):
-
-        raise Exception("Not implemented")
+        from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant
+        return weight_quant(weight, self.block_size)
 
     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        qweight, weight_scale, input_scale = weights
+        if len(weights) == 3:
+            qweight, weight_scale, input_scale = weights
+        else:
+            qweight, weight_scale = weights
+            input_scale = None
         m, k = input_tensor.shape
         n = weights[0].shape[1]
         if input_scale is None:
diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
new file mode 100644
index 000000000..0021a94da
--- /dev/null
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -0,0 +1,44 @@
+import torch
+import triton
+import triton.language as tl
+from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant
+
+@triton.jit
+def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n_blocks = tl.cdiv(N, BLOCK_SIZE)
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
+
+    amax = tl.max(tl.abs(x))
+
+    max_fp8e4m3_val = 448.0
+    scale = amax / (max_fp8e4m3_val + 1e-6)
+
+    y = (x / scale).to(y_ptr.dtype.element_ty)
+
+    tl.store(y_ptr + offs, y, mask=mask)
+    tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)
+
+
+def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    assert x.dim() == 2, 'Input tensor must have 2 dimensions'
+    M, N = x.size()
+
+    y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
+
+    num_blocks_m = triton.cdiv(M, block_size)
+    num_blocks_n = triton.cdiv(N, block_size)
+    s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+
+    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
+    return y_quant, s_scales
+
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 57d10bdcd..441116264 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -114,13 +114,11 @@ def _moe_ffn_edp(
             router_logits=router_logits,
             top_k=self.num_experts_per_tok,
             renormalize=self.norm_topk_prob,
-            use_grouped_topk=self.n_group,
-            topk_group=self.topk_group,
-            num_expert_group=self.n_group,
+            use_grouped_topk=False,
+            topk_group=None,
+            num_expert_group=None,
             is_prefill=infer_state.is_prefill,
         )
-        if self.n_shared_experts is not None:
-            ep_output.add_(shared_output)
         ep_output = ep_output.view(token_num, hidden_dim)
 
         return ep_output
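A quick, hedged sketch of how the quantize path introduced in PATCH 1 could be exercised on its own (illustrative only, not part of the patch; it assumes a CUDA device, since the Triton kernel only runs on GPU, and arbitrary weight sizes):

    import torch
    from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

    # Quantize a 2D weight into 128x128 blocks (the sizes here are made up).
    w = torch.randn(4096, 2048, dtype=torch.bfloat16, device="cuda")
    qweight, scales = weight_quant(w, block_size=128)
    # qweight: [4096, 2048] float8_e4m3fn; scales holds one float32 per block,
    # shape [ceil(4096/128), ceil(2048/128)] = [32, 16], with scale roughly amax(block) / 448.

The quantized weight plus the per-block scales are the pair that the deepgemm apply() path above unpacks when input_scale is absent.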
From bccce16bc7a2062259280462fab7c3d69be72b64 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 2 Jul 2025 10:54:46 +0800
Subject: [PATCH 2/8] all

---
 .../meta_weights/fused_moe_weight_ep.py       |  6 +-
 .../fp8/fp8w8a8_block_quant_kernel.py         | 87 +++++++++++++++----
 .../layer_infer/transformer_layer_infer.py    |  4 +-
 lightllm/models/qwen3_moe/model.py            | 10 +++
 lightllm/utils/dist_utils.py                  | 12 +++
 5 files changed, 99 insertions(+), 20 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
index f7a24ae0f..051a11c9f 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
@@ -84,16 +84,18 @@ def __init__(
         self.e_score_correction_bias = None
         self.w2_list = [None] * ep_load_expert_num
         self.w2_scale_list = [None] * ep_load_expert_num
-        self.scoring_func = network_config["scoring_func"]
+        self.scoring_func = network_config.get("scoring_func","softmax")
         self.w1 = [None, None]  # weight, weight_scale
         self.w2 = [None, None]  # weight, weight_scale
         self.use_fp8_w8a8 = self.quant_method is not None
-
+        network_config["n_group"] = network_config.get("n_group", 0)
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.use_grouped_topk = network_config["n_group"] > 0
         self.norm_topk_prob = network_config["norm_topk_prob"]
         self.n_group = network_config["n_group"]
+        network_config["topk_group"] = network_config.get("topk_group", 0)
         self.topk_group = network_config["topk_group"]
+        network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 0)
         self.routed_scaling_factor = network_config["routed_scaling_factor"]
         self.lock = threading.Lock()
 
diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
index 0021a94da..0f585dd5e 100644
--- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -1,10 +1,24 @@
 import torch
 import triton
 import triton.language as tl
-from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant
+from lightllm.utils.dist_utils import get_current_device_id
+from typing import Tuple
+
 
 @triton.jit
-def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE) -> None:
+    """
+    Triton kernel for weight quantization to FP8 e4m3 format.
+
+    Args:
+        x_ptr: Input tensor pointer (float32)
+        s_ptr: Output scale tensor pointer (float32)
+        y_ptr: Output quantized tensor pointer (float8_e4m3fn)
+        M: Number of rows
+        N: Number of columns
+        BLOCK_SIZE: Size of the processing block
+    """
+
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     n_blocks = tl.cdiv(N, BLOCK_SIZE)
@@ -17,28 +31,71 @@ def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
 
     amax = tl.max(tl.abs(x))
-
     max_fp8e4m3_val = 448.0
     scale = amax / (max_fp8e4m3_val + 1e-6)
-
     y = (x / scale).to(y_ptr.dtype.element_ty)
 
     tl.store(y_ptr + offs, y, mask=mask)
     tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)
 
 
-def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.is_contiguous(), 'Input tensor must be contiguous'
-    assert x.dim() == 2, 'Input tensor must have 2 dimensions'
-    M, N = x.size()
-
-    y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
-
-    num_blocks_m = triton.cdiv(M, block_size)
-    num_blocks_n = triton.cdiv(N, block_size)
-    s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
-
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
-    return y_quant, s_scales
-
+def weight_quant(
+    x: torch.Tensor,
+    block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if not x.is_contiguous():
+        raise ValueError("Input tensor must be contiguous")
+
+    if not x.is_cuda:
+        x = x.cuda(get_current_device_id())
+
+    if x.dim() not in (2, 3):
+        raise ValueError(f"Input tensor must be 2D or 3D, got {x.dim()}D")
+
+    # Handle 3D input by processing each batch
+    if x.dim() == 3:
+        batch_size, M, N = x.shape
+        y_quant = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+        num_blocks_m = triton.cdiv(M, block_size)
+        num_blocks_n = triton.cdiv(N, block_size)
+        s_scales = torch.empty(
+            (batch_size, num_blocks_m, num_blocks_n),
+            dtype=torch.float32,
+            device=x.device
+        )
+
+        grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']),
+                             triton.cdiv(N, meta['BLOCK_SIZE']))
+
+        for i in range(batch_size):
+            weight_quant_kernel[grid](
+                x[i],
+                s_scales[i],
+                y_quant[i],
+                M,
+                N,
+                BLOCK_SIZE=block_size
+            )
+
+        return y_quant, s_scales
+
+    # Handle 2D input
+    M, N = x.shape
+    y_quant = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    num_blocks_m = triton.cdiv(M, block_size)
+    num_blocks_n = triton.cdiv(N, block_size)
+    s_scales = torch.empty(
+        (num_blocks_m, num_blocks_n),
+        dtype=torch.float32,
+        device=x.device
+    )
+
+    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']),
+                         triton.cdiv(N, meta['BLOCK_SIZE']))
+
+    weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
+
+    return y_quant.t(), s_scales.t()
+
\ No newline at end of file
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 441116264..9e6acf135 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -105,9 +105,7 @@ def _moe_ffn_edp(
         hidden_states = input
         token_num, hidden_dim = hidden_states.shape
 
-        if self.n_shared_experts is not None:
-            shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)
-
+        
         router_logits = layer_weight.moe_gate.mm(hidden_states)
         ep_output = layer_weight.experts.experts(
             hidden_states,
diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index b3421a325..cdcedf378 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -5,6 +5,7 @@
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3.model import Qwen3TpPartModel
 from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager
 
 logger = init_logger(__name__)
 
@@ -21,3 +22,12 @@ class Qwen3MOEModel(Qwen3TpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _init_config(self):
+        super()._init_config()
+        # self.config["num_hidden_layers"] = 2
+        # self.config["n_layer"] = 2
+
+    def _init_custom(self):
+        super()._init_custom()
+        dist_group_manager.new_deepep_group(256, self.config["hidden_size"])
\ No newline at end of file
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 8e8b9d286..4e3ae6c46 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -77,6 +77,18 @@ def init_vision_distributed_env(kvargs):
     del _a
 
 
+def run_once(func):
+    has_run = False
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            has_run = True
+            return func(*args, **kwargs)
+        else:
+            return None
+    return wrapper
+
+@run_once
 def init_distributed_env(kvargs):
     assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
     node_world_size = kvargs["world_size"] // kvargs["args"].nnodes

From be25bfd39a86444417b9356c39a1bc97a465475c Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 2 Jul 2025 11:06:18 +0800
Subject: [PATCH 3/8] fix

---
 .../triton_quant/fp8/fp8w8a8_block_quant_kernel.py | 97 ++++++-------------
 1 file changed, 28 insertions(+), 69 deletions(-)

diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
index 0f585dd5e..e2cf26425 100644
--- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -2,23 +2,9 @@ import torch
 import triton
 import triton.language as tl
 from lightllm.utils.dist_utils import get_current_device_id
-from typing import Tuple
-
 
 @triton.jit
-def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE) -> None:
-    """
-    Triton kernel for weight quantization to FP8 e4m3 format.
-
-    Args:
-        x_ptr: Input tensor pointer (float32)
-        s_ptr: Output scale tensor pointer (float32)
-        y_ptr: Output quantized tensor pointer (float8_e4m3fn)
-        M: Number of rows
-        N: Number of columns
-        BLOCK_SIZE: Size of the processing block
-    """
-
+def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     n_blocks = tl.cdiv(N, BLOCK_SIZE)
@@ -31,71 +17,44 @@ def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE) -> None:
     x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
 
     amax = tl.max(tl.abs(x))
+
     max_fp8e4m3_val = 448.0
     scale = amax / (max_fp8e4m3_val + 1e-6)
+
     y = (x / scale).to(y_ptr.dtype.element_ty)
 
     tl.store(y_ptr + offs, y, mask=mask)
     tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)
 
 
-def weight_quant(
-    x: torch.Tensor,
-    block_size: int = 128
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    if not x.is_contiguous():
-        raise ValueError("Input tensor must be contiguous")
-
-    if not x.is_cuda:
-        x = x.cuda(get_current_device_id())
-
-    if x.dim() not in (2, 3):
-        raise ValueError(f"Input tensor must be 2D or 3D, got {x.dim()}D")
-
-    # Handle 3D input by processing each batch
-    if x.dim() == 3:
-        batch_size, M, N = x.shape
-        y_quant = torch.empty_like(x, dtype=torch.float8_e4m3fn)
-        num_blocks_m = triton.cdiv(M, block_size)
-        num_blocks_n = triton.cdiv(N, block_size)
-        s_scales = torch.empty(
-            (batch_size, num_blocks_m, num_blocks_n),
-            dtype=torch.float32,
-            device=x.device
-        )
-
-        grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']),
-                             triton.cdiv(N, meta['BLOCK_SIZE']))
-
-        for i in range(batch_size):
-            weight_quant_kernel[grid](
-                x[i],
-                s_scales[i],
-                y_quant[i],
-                M,
-                N,
-                BLOCK_SIZE=block_size
-            )
-
-        return y_quant, s_scales
-
-    # Handle 2D input
-    M, N = x.shape
-    y_quant = torch.empty_like(x, dtype=torch.float8_e4m3fn)
-    num_blocks_m = triton.cdiv(M, block_size)
-    num_blocks_n = triton.cdiv(N, block_size)
-    s_scales = torch.empty(
-        (num_blocks_m, num_blocks_n),
-        dtype=torch.float32,
-        device=x.device
-    )
-
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']),
-                         triton.cdiv(N, meta['BLOCK_SIZE']))
-
+def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    M, N = x.size()
+
+    y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
+
+    num_blocks_m = triton.cdiv(M, block_size)
+    num_blocks_n = triton.cdiv(N, block_size)
+    s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+
+    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
     weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
-
-    return y_quant.t(), s_scales.t()
+    return y_quant, s_scales
+
+
+def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    x = x.cuda(get_current_device_id())
+    if x.dim() == 3:
+        y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
+        num_blocks_m = triton.cdiv(x.shape[1], block_size)
+        num_blocks_n = triton.cdiv(x.shape[2], block_size)
+        s_scales = torch.empty((x.shape[0], num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+        for i in range(x.shape[0]):
+            y_quant[i], s_scales[i] = mm_weight_quant(x[i], block_size)
+        return y_quant, s_scales
+    else:
+        y_quant, s_scales = mm_weight_quant(x, block_size)
+        return y_quant.t(), s_scales.t()
\ No newline at end of file
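As a reading aid for the rework in PATCH 3 (illustrative only, not part of the patches): 3D inputs are treated as a stack of independent 2D matrices and keep their layout, while the 2D path now returns transposed tensors. The sizes below are made up:

    import torch
    from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

    # Stacked expert weights [num_experts, K, N]; weight_quant moves them to the current CUDA device.
    w3 = torch.randn(8, 512, 256, dtype=torch.bfloat16)
    q3, s3 = weight_quant(w3, block_size=128)
    # q3: [8, 512, 256] float8_e4m3fn; s3: [8, ceil(512/128), ceil(256/128)] = [8, 4, 2] float32

    # Plain 2D weight: note the transpose on both return values in this revision.
    w2 = torch.randn(512, 256, dtype=torch.bfloat16)
    q2, s2 = weight_quant(w2, block_size=128)
    # q2: [256, 512] (transposed quantized weight); s2: [2, 4] float32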
From 94aecac1939f7048a0feadbf993caa669aa1a4de Mon Sep 17 00:00:00 2001
From: sufubao <47234901+sufubao@users.noreply.github.com>
Date: Wed, 2 Jul 2025 11:24:14 +0800
Subject: [PATCH 4/8] fix

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../triton_quant/fp8/fp8w8a8_block_quant_kernel.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
index e2cf26425..3961ee150 100644
--- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -19,9 +19,8 @@ def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     amax = tl.max(tl.abs(x))
 
     max_fp8e4m3_val = 448.0
-    scale = amax / (max_fp8e4m3_val + 1e-6)
-
-    y = (x / scale).to(y_ptr.dtype.element_ty)
+    scale = amax / max_fp8e4m3_val
+    y = (x / (scale + 1e-6)).to(y_ptr.dtype.element_ty)
 
     tl.store(y_ptr + offs, y, mask=mask)
     tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)

From 8aed808b89c1f771240342b318a6ae987e291ce6 Mon Sep 17 00:00:00 2001
From: sufubao <47234901+sufubao@users.noreply.github.com>
Date: Wed, 2 Jul 2025 11:26:40 +0800
Subject: [PATCH 5/8] fix

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 lightllm/models/qwen3_moe/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index cdcedf378..bac7de83e 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -30,4 +30,4 @@ def _init_config(self):
 
     def _init_custom(self):
         super()._init_custom()
-        dist_group_manager.new_deepep_group(256, self.config["hidden_size"])
\ No newline at end of file
+        dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
\ No newline at end of file

From 91c542b6c3eceab877a15ac236e14d4298a07b90 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 2 Jul 2025 11:34:41 +0800
Subject: [PATCH 6/8] fix

---
 lightllm/models/qwen3_moe/model.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index bac7de83e..10a505127 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -23,11 +23,6 @@ def __init__(self, kvargs):
         super().__init__(kvargs)
         return
 
-    def _init_config(self):
-        super()._init_config()
-        # self.config["num_hidden_layers"] = 2
-        # self.config["n_layer"] = 2
-
     def _init_custom(self):
         super()._init_custom()
-        dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
\ No newline at end of file
+        dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
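For the scale fix in PATCH 4, a plain-PyTorch reference of the same per-128x128-block scaling is handy when eyeballing the kernel's output. This is only a sketch of the formula the fixed kernel uses (scale = amax(block) / 448, y = block / (scale + 1e-6)), not code from the series:

    import torch

    def block_quant_reference(x: torch.Tensor, block_size: int = 128):
        # Reference per-block FP8 quantization mirroring the fixed Triton kernel.
        M, N = x.shape
        bm, bn = -(-M // block_size), -(-N // block_size)  # ceil division
        scales = torch.empty(bm, bn, dtype=torch.float32, device=x.device)
        y = torch.empty(M, N, dtype=torch.float8_e4m3fn, device=x.device)
        for i in range(bm):
            for j in range(bn):
                rows = slice(i * block_size, (i + 1) * block_size)
                cols = slice(j * block_size, (j + 1) * block_size)
                blk = x[rows, cols].float()
                scale = blk.abs().amax() / 448.0  # per-block scale, as in the kernel
                scales[i, j] = scale
                y[rows, cols] = (blk / (scale + 1e-6)).to(torch.float8_e4m3fn)
        return y, scales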
From e2fe0a9cd162458fb4e72ebb2b5f9eaea1aef9bd Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 2 Jul 2025 14:12:13 +0800
Subject: [PATCH 7/8] fix

---
 .../layer_weights/meta_weights/fused_moe_weight_ep.py    |  4 ++--
 lightllm/common/quantization/deepgemm_quant.py           |  1 +
 .../triton_quant/fp8/fp8w8a8_block_quant_kernel.py       | 11 +++++------
 .../qwen3_moe/layer_infer/transformer_layer_infer.py     |  2 +-
 lightllm/utils/dist_utils.py                             |  5 ++++-
 5 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
index 051a11c9f..cc925525c 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
@@ -84,7 +84,7 @@ def __init__(
         self.e_score_correction_bias = None
         self.w2_list = [None] * ep_load_expert_num
         self.w2_scale_list = [None] * ep_load_expert_num
-        self.scoring_func = network_config.get("scoring_func","softmax")
+        self.scoring_func = network_config.get("scoring_func", "softmax")
         self.w1 = [None, None]  # weight, weight_scale
         self.w2 = [None, None]  # weight, weight_scale
         self.use_fp8_w8a8 = self.quant_method is not None
@@ -93,7 +93,7 @@ def __init__(
         self.use_grouped_topk = network_config["n_group"] > 0
         self.norm_topk_prob = network_config["norm_topk_prob"]
         self.n_group = network_config["n_group"]
-        network_config["topk_group"] = network_config.get("topk_group", 0) 
+        network_config["topk_group"] = network_config.get("topk_group", 0)
         self.topk_group = network_config["topk_group"]
         network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 0)
         self.routed_scaling_factor = network_config["routed_scaling_factor"]
diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py
index 2d8f828be..8d14805ad 100644
--- a/lightllm/common/quantization/deepgemm_quant.py
+++ b/lightllm/common/quantization/deepgemm_quant.py
@@ -42,6 +42,7 @@ def __init__(self):
 
     def quantize(self, weight: torch.Tensor):
         from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant
+
         return weight_quant(weight, self.block_size)
 
     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
index 3961ee150..11c1897d7 100644
--- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -3,6 +3,7 @@
 import triton.language as tl
 from lightllm.utils.dist_utils import get_current_device_id
 
+
 @triton.jit
 def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     pid_m = tl.program_id(axis=0)
@@ -18,7 +19,7 @@ def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
 
     amax = tl.max(tl.abs(x))
 
-    max_fp8e4m3_val = 448.0 
+    max_fp8e4m3_val = 448.0
     scale = amax / max_fp8e4m3_val
     y = (x / (scale + 1e-6)).to(y_ptr.dtype.element_ty)
 
@@ -27,7 +28,7 @@ def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
 
 
 def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    assert x.is_contiguous(), "Input tensor must be contiguous"
     M, N = x.size()
 
     y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
@@ -36,13 +37,13 @@ def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tenso
     num_blocks_n = triton.cdiv(N, block_size)
     s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
 
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
     weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
     return y_quant, s_scales
 
 
 def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    assert x.is_contiguous(), "Input tensor must be contiguous"
     x = x.cuda(get_current_device_id())
     if x.dim() == 3:
         y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
@@ -55,5 +56,3 @@ def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor,
     else:
         y_quant, s_scales = mm_weight_quant(x, block_size)
         return y_quant.t(), s_scales.t()
-
-
\ No newline at end of file
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 9e6acf135..2e01bc6e4 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -105,7 +105,7 @@ def _moe_ffn_edp(
         hidden_states = input
         token_num, hidden_dim = hidden_states.shape
 
-        
+
         router_logits = layer_weight.moe_gate.mm(hidden_states)
         ep_output = layer_weight.experts.experts(
             hidden_states,
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 4e3ae6c46..c852df86a 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -79,15 +79,18 @@ def init_vision_distributed_env(kvargs):
 
 def run_once(func):
     has_run = False
+
     def wrapper(*args, **kwargs):
         nonlocal has_run
         if not has_run:
            has_run = True
            return func(*args, **kwargs)
         else:
-            return None
+            return None
+
     return wrapper
 
+
 @run_once
 def init_distributed_env(kvargs):
     assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"

From 76795dd5d83fa13095de86268f1393175deada1b Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 2 Jul 2025 14:41:45 +0800
Subject: [PATCH 8/8] fix

---
 lightllm/utils/dist_utils.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index c852df86a..8e8b9d286 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -77,21 +77,6 @@ def init_vision_distributed_env(kvargs):
     del _a
 
 
-def run_once(func):
-    has_run = False
-
-    def wrapper(*args, **kwargs):
-        nonlocal has_run
-        if not has_run:
-            has_run = True
-            return func(*args, **kwargs)
-        else:
-            return None
-
-    return wrapper
-
-
-@run_once
 def init_distributed_env(kvargs):
     assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
     node_world_size = kvargs["world_size"] // kvargs["args"].nnodes
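To close the loop on the series: reconstructing a weight from (qweight, scales) only requires broadcasting each per-block scale back over its tile and undoing the division performed by the kernel. A hedged round-trip sketch (illustrative, not part of the patches; remember that the 2D path of weight_quant returns transposed tensors):

    import torch

    def block_dequant_reference(q: torch.Tensor, scales: torch.Tensor, block_size: int = 128) -> torch.Tensor:
        # Expand each per-block scale over its 128x128 tile and invert y = x / (scale + 1e-6).
        M, N = q.shape
        s = scales.repeat_interleave(block_size, dim=0)[:M]
        s = s.repeat_interleave(block_size, dim=1)[:, :N]
        return q.float() * (s + 1e-6)

    # Round trip on a random 2D weight:
    # w = torch.randn(512, 256, device="cuda")
    # q, s = weight_quant(w)
    # max_err = (w - block_dequant_reference(q.t(), s.t())).abs().max()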