
Commit 27cea3a ("0529")
1 parent: 815a698

3 files changed: +16 −8 lines

lightllm/common/quantization/w8a8_quant.py

Lines changed: 12 additions & 3 deletions
@@ -6,6 +6,15 @@
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm
+from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
+
+if not HAS_LIGHTLLM_KERNEL:
+
+    def scaled_fp8_quant(tensor, *args, **kwargs):
+        return light_ops.per_token_quant_bf16_fp8(tensor)
+
+else:
+    scaled_fp8_quant = vllm_ops.scaled_fp8_quant
 
 
 class BaseQuantizationMethod(QuantizationMethod):
@@ -71,7 +80,7 @@ def __init__(self):
     def quantize(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
+        qweight, weight_scale = scaled_fp8_quant(
             weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
         return qweight.transpose(0, 1), weight_scale
@@ -82,7 +91,7 @@ def quantize_moe(self, weight):
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
         for i in range(num_experts):
-            qweight, weight_scale = vllm_ops.scaled_fp8_quant(
+            qweight, weight_scale = scaled_fp8_quant(
                 weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
             )
             qweights[i] = qweight
@@ -91,7 +100,7 @@ def quantize_moe(self, weight):
         return qweights, weight_scale
 
     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        x_q, x_scale = vllm_ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:
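
The point of the new module-level shim in this file is that quantize(), quantize_moe() and apply() all call a single scaled_fp8_quant name, with the backend (light_ops.per_token_quant_bf16_fp8 or vllm_ops.scaled_fp8_quant) chosen once at import time. For orientation, here is a minimal, hypothetical pure-PyTorch sketch of what a dynamic per-token FP8 quantizer of this shape does. It is not the implementation behind either backend; the extra arguments exist only so call sites that pass scale=None or use_per_token_if_dynamic keep working.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_fp8_quant_sketch(x: torch.Tensor, *args, **kwargs):
    # Dynamic per-token (per-row) scaling: one fp32 scale per row of x.
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    scale = amax / FP8_MAX
    x_q = (x.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale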

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 3 deletions
@@ -26,6 +26,7 @@
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 
 logger = init_logger(__name__)
 
@@ -539,11 +540,9 @@ def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo
         calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
 
-        from lightllm_ppl_kernel import group8_int8kv_decode_attention
-
         # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v,
         # at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch)
-        group8_int8kv_decode_attention(
+        light_ops.group8_int8kv_decode_attention(
             o_tensor.view(calcu_shape1),
             q.view(calcu_shape1),
             infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
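
Both attention call sites in this commit drop a function-local `from lightllm_ppl_kernel import ...` and instead go through the `light_ops` namespace imported at module scope. As a rough sketch of how an availability flag plus kernel namespace like HAS_LIGHTLLM_KERNEL / light_ops is typically wired up (the actual lightllm.utils.light_utils may differ, and the extension module name below is hypothetical):

# Guarded import at module scope: the flag records whether the compiled
# extension is present, and light_ops is None when it is not.
try:
    import lightllm_kernel as light_ops  # hypothetical compiled-extension name
    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    light_ops = None
    HAS_LIGHTLLM_KERNEL = False

Hoisting the import also means the kernel module is resolved once at load time rather than looked up on every decode call.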

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,5 @@
 import torch
+from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 
 
 def token_decode_attention_flash_decoding(
@@ -18,7 +19,6 @@ def token_decode_attention_flash_decoding(
     max_len_in_batch = infer_state.max_len_in_batch
     calcu_shape1 = (batch_size, q_head_num, head_dim)
 
-    from lightllm_ppl_int8kv_flashdecoding_kernel import group8_int8kv_flashdecoding_stage1
     from .flash_decoding_stage2 import flash_decode_stage2
 
     o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
@@ -30,7 +30,7 @@ def token_decode_attention_flash_decoding(
         [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda"
     )
 
-    group8_int8kv_flashdecoding_stage1(
+    light_ops.group8_int8kv_flashdecoding_stage1(
         BLOCK_SEQ,
         mid_o,
         mid_o_logexpsum,
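
The flash-decoding path now assumes light_ops is importable whenever this kernel is selected. One possible way for a caller to make that precondition explicit is a small guard like the following; this helper is purely illustrative and is not part of the commit:

def _require_lightllm_kernel():
    # Fail fast with a clear message if the compiled extension backing
    # light_ops.group8_int8kv_flashdecoding_stage1 is not installed.
    if not HAS_LIGHTLLM_KERNEL:
        raise RuntimeError(
            "lightllm kernel extension is required for the ppl int8kv flash-decoding path"
        )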
