diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
index aa3b5f61d..120880143 100644
--- a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
@@ -116,7 +116,7 @@ def per_token_group_quant_fp8(
     if HAS_SGL_KERNEL:
         finfo = torch.finfo(dtype)
         fp8_max, fp8_min = finfo.max, finfo.min
-        sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
+        sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False)
     else:
         lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn)
 
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index 7f0601c92..8fb832ad4 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -44,6 +44,9 @@ def get_model_architectures(model_path: str):
 def get_vocab_size(model_path: str):
     try:
         config_json = get_config_json(model_path)
+        if "llm_config" in config_json:
+            vocab_size = int(config_json["llm_config"]["vocab_size"])
+            return vocab_size
         vocab_size = config_json["vocab_size"]
         if not isinstance(vocab_size, int):
             vocab_size = int(vocab_size)