
Commit 7ea127b

Author: niushengxiao
feat: add per tensor fp8 kv quant for flashinfer

1 parent 79b8280 · commit 7ea127b
13 files changed: +2158 −1403 lines

lightllm/common/basemodel/basemodel.py
Lines changed: 2 additions & 10 deletions

@@ -23,20 +23,13 @@
 from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
+from lightllm.utils.envs_utils import set_model_init_status


 logger = init_logger(__name__)

 torch.backends.cudnn.enabled = True

-g_model_init_done = False
-
-
-def get_model_init_status():
-    # get the model initialization status
-    global g_model_init_done
-    return g_model_init_done
-

 class TpPartBaseModel:
     # weight class
@@ -111,8 +104,7 @@ def __init__(self, kvargs):
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
-        global g_model_init_done
-        g_model_init_done = True
+        set_model_init_status(True)
         return

     def _init_config(self):

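The model-init flag moves out of basemodel.py and into lightllm.utils.envs_utils (see the envs_utils.py diff further down), presumably so that offline_fp8_quant_mem_manager can read it without importing basemodel and creating an import cycle. A minimal usage sketch of the getter/setter pair, assuming lightllm at this commit is importable:

from lightllm.utils.envs_utils import get_model_init_status, set_model_init_status

print(get_model_init_status())  # False until the model finishes initializing
set_model_init_status(True)     # TpPartBaseModel.__init__ calls this at the end (diff above)
print(get_model_init_status())  # True
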
lightllm/common/offline_fp8_quant_mem_manager.py
Lines changed: 39 additions & 20 deletions

@@ -7,8 +7,7 @@
 from lightllm.utils.dist_utils import get_global_rank
 from lightllm.utils.config_utils import get_model_architectures
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.common.basemodel.basemodel import get_model_init_status
+from lightllm.utils.envs_utils import get_env_start_args, get_model_init_status

 logger = init_logger(__name__)

@@ -30,25 +29,33 @@ def __init__(
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
         self.count = 0
         self.scales = None
+        self.scales_list = None
         self.abs_max = None

         if is_export_mode:
-            self.abs_max = torch.zeros((layer_num, 2 * head_num), dtype=torch.float32, device="cuda")
+            scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
+            self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda")
         elif get_env_start_args().kv_quant_calibration_config_path is not None:
            logger.info(
                f"kv_quant_calibration_config_path {get_env_start_args().kv_quant_calibration_config_path} is set, "
                "will load kv quant calibration config"
            )
            cfg = self._load_and_check_config()

-            self.scales = torch.tensor(cfg["scales"], dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
-            if dist.is_initialized() and dist.get_world_size() > 1:
+            self.scales_list = cfg["scales"]
+            self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
+            if not get_env_start_args().enable_fa3:
+                self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
+            if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                half_head = self.total_head_num // 2
                start_head = dist.get_rank() * head_num
                end_head = start_head + head_num
                k_scales = self.scales[:, start_head:end_head].contiguous()
                v_scales = self.scales[:, start_head + half_head : end_head + half_head].contiguous()
-                self.scales = torch.cat((k_scales, v_scales), dim=-1)
+                current_scales = torch.cat((k_scales, v_scales), dim=-1)
+
+                self.scales_list = current_scales.tolist()
+                self.scales = current_scales
         else:
             logger.warning("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales")

@@ -74,8 +81,12 @@ def _load_and_check_config(self):
                raise ValueError(
                    f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
                )
-            if cfg["quant_type"] != "per_head":
-                raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
+            if get_env_start_args().enable_fa3:
+                if cfg["quant_type"] != "per_head":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
+            else:
+                if cfg["quant_type"] != "per_tensor":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match flashinfer backend")

            return cfg
        else:
@@ -93,21 +104,29 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
            logger.info("kv cache calibration mode will collect kv cache data for quantization calibration")

        if self.abs_max is not None and self.count >= warmup_counts:
-            kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
+            if get_env_start_args().enable_fa3:
+                kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
+            else:
+                k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
+                v_max = kv_buffer[:, self.head_num :, :].abs().amax(dim=()).to(torch.float32)
+                kv_max = torch.tensor([k_max, v_max], device="cuda", dtype=torch.float32)
            self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max)
            if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1:
                final_abs_max = self.abs_max
                if dist.is_initialized() and dist.get_world_size() > 1:
-                    k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
-                    k_max = k_max.contiguous()
-                    v_max = v_max.contiguous()
-                    gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
-                    gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
-                    dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
-                    dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
-                    k_max = torch.cat(gathered_k_max, dim=-1)
-                    v_max = torch.cat(gathered_v_max, dim=-1)
-                    final_abs_max = torch.cat((k_max, v_max), dim=-1)
+                    if get_env_start_args().enable_fa3:
+                        k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
+                        k_max = k_max.contiguous()
+                        v_max = v_max.contiguous()
+                        gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
+                        gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
+                        dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
+                        dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
+                        k_max = torch.cat(gathered_k_max, dim=-1)
+                        v_max = torch.cat(gathered_v_max, dim=-1)
+                        final_abs_max = torch.cat((k_max, v_max), dim=-1)
+                    else:
+                        dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False)

                self.scales = final_abs_max / self.qmax
                self.scales = torch.where(self.scales > 0, self.scales, torch.ones_like(self.scales))
@@ -124,7 +143,7 @@ def _export_calibration_data(self):
        cfg = {
            "version": "1.0",
            "architectures": model_arch,
-            "quant_type": "per_head",
+            "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
            "qmin": self.qmin,
            "qmax": self.qmax,
            "num_layers": self.layer_num,

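For the flashinfer path the calibration is per tensor: one K abs-max and one V abs-max per layer (rather than per head), reduced with MAX across tensor-parallel ranks and then divided by the fp8 e4m3 maximum (self.qmax). A minimal single-process sketch of that math, with illustrative names rather than the manager's real API:

import torch

FP8_E4M3_MAX = 448.0  # qmax for torch.float8_e4m3fn

def per_tensor_kv_scales(kv_buffer: torch.Tensor, head_num: int) -> torch.Tensor:
    # kv_buffer: [token_num, 2 * head_num, head_dim], K heads first, then V heads
    k_max = kv_buffer[:, :head_num, :].abs().amax().to(torch.float32)
    v_max = kv_buffer[:, head_num:, :].abs().amax().to(torch.float32)
    abs_max = torch.stack([k_max, v_max])  # shape [2]: one abs-max for K, one for V
    # with tensor parallelism, this is where dist.all_reduce(abs_max, op=MAX) would go
    scales = abs_max / FP8_E4M3_MAX
    # never emit a zero scale for a slice that was never written
    return torch.where(scales > 0, scales, torch.ones_like(scales))

print(per_tensor_kv_scales(torch.randn(16, 2 * 8, 128), head_num=8))  # tensor([k_scale, v_scale])
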
lightllm/models/llama/layer_infer/transformer_layer_infer.py
Lines changed: 66 additions & 7 deletions

@@ -77,19 +77,24 @@ def _bind_attention(self):
                LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self
            )
            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self)
-        elif "export_fp8kv_calibration" in self.mode or not self.mode:
+        elif "export_fp8kv_calibration" in self.mode:
            self._context_attention_kernel = partial(
                LlamaTransformerLayerInfer._context_attention_flashattention, self
            )
            self._token_attention_kernel = partial(
                LlamaTransformerLayerInfer._token_decode_attention_flashattention, self
            )
-            if "export_fp8kv_calibration" in self.mode:
-                self._copy_kv_to_mem_cache = partial(
-                    LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self
-                )
-            else:
-                self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            self._copy_kv_to_mem_cache = partial(
+                LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self
+            )
+        elif not self.mode:
+            self._context_attention_kernel = partial(
+                LlamaTransformerLayerInfer._context_attention_flashattention, self
+            )
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._token_decode_attention_flashattention, self
+            )
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
        else:
            raise Exception(f"Unsupported mode for fa3 backend: {self.mode}")
        return
@@ -129,6 +134,15 @@ def _bind_attention(self):
        elif "triton_int8kv" in self.mode:
            self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self)
            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self)
+        elif "offline_calibration_fp8kv" in self.mode:
+            assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self)
+            self._context_attention_kernel = partial(
+                LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self
+            )
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self
+            )
        elif "triton_flashdecoding" in self.mode:
            self._token_attention_kernel = partial(
                LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self
@@ -147,6 +161,11 @@ def _bind_attention(self):
                LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self
            )
            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+        elif "export_fp8kv_calibration" in self.mode:
+            self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self)
+            self._copy_kv_to_mem_cache = partial(
+                LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self
+            )
        elif not self.mode:
            if get_env_start_args().enable_flashinfer_decode:
                self._token_attention_kernel = partial(
@@ -214,6 +233,26 @@ def _tpsp_get_qkv(
        )
        return q, cache_kv

+    def _context_attention_flashinfer_kernel_fp8(
+        self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        kv = kv.unsqueeze(1)
+        k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn)
+        v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn)
+        offline_scales = infer_state.mem_manager.scales_list
+        k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None
+        v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None
+        infer_state.prefill_wrapper.run(
+            q.view(q.shape[0], -1, self.head_dim_),
+            (k, v),
+            k_scale=k_descale,
+            v_scale=v_descale,
+            out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+        )
+        return o_tensor
+
    def _context_attention_flashinfer_kernel(
        self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
@@ -474,6 +513,26 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager):
        )
        return

+    def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1)
+        k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn)
+        v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn)
+        offline_scales = infer_state.mem_manager.scales_list
+        k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None
+        v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None
+        infer_state.decode_wrapper.run(
+            q.view(calcu_shape1),
+            (k, v),
+            k_scale=k_descale,
+            v_scale=v_descale,
+            out=o_tensor.view(calcu_shape1),
+        )
+        return o_tensor
+
    def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None):
        batch_size = infer_state.batch_size
        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

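The two new methods hand the per-layer (k_scale, v_scale) pair from scales_list to the flashinfer prefill/decode wrappers, which dequantize the float8_e4m3fn cache on the fly. The round trip they rely on is sketched below; the helper names are illustrative, and a torch build with torch.float8_e4m3fn is assumed:

import torch

def quantize_kv(kv: torch.Tensor, k_scale: float, v_scale: float, head_num: int):
    # kv: [token_num, 2 * head_num, head_dim] in fp16/bf16, K heads first, then V heads
    k_q = (kv[:, :head_num] / k_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    v_q = (kv[:, head_num:] / v_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return k_q, v_q

def dequantize_kv(k_q, v_q, k_scale: float, v_scale: float, dtype=torch.bfloat16):
    # what passing k_scale / v_scale to the attention wrapper effectively does
    return k_q.to(dtype) * k_scale, v_q.to(dtype) * v_scale

kv = torch.randn(4, 2 * 8, 128, dtype=torch.bfloat16)
scale = (kv.abs().amax() / 448.0).item()    # per-tensor scale from calibration
k, v = dequantize_kv(*quantize_kv(kv, scale, scale, 8), scale, scale)
print((k - kv[:, :8]).abs().max())          # small quantization error
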
lightllm/models/llama/model.py
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def __init__(self, model):
            ),
        ]
        self.q_data_type = model.data_type
-        self.kv_data_type = model.data_type
+        self.kv_data_type = torch.float8_e4m3fn if "offline_calibration_fp8kv" in model.mode else model.data_type


 @ModelRegistry("llama")

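The kv_data_type switch tells the flashinfer state objects to plan for float8_e4m3fn keys/values. The .view(torch.float8_e4m3fn) calls in the attention methods above suggest the cache itself is held in a byte-wide buffer and reinterpreted without a copy at attention time; a tiny illustration of that reinterpretation, under that assumption (the uint8 storage here is hypothetical, and a torch build with float8 support is assumed):

import torch

storage = torch.zeros(4, 16, 128, dtype=torch.uint8)  # stand-in for one layer's kv cache page
as_fp8 = storage.view(torch.float8_e4m3fn)            # same bytes, no copy, same shape
print(as_fp8.dtype, as_fp8.shape)                     # torch.float8_e4m3fn torch.Size([4, 16, 128])
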
lightllm/server/api_cli.py
Lines changed: 4 additions & 2 deletions

@@ -170,10 +170,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
        triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
        triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
        triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2;
-        offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 backend,
+        offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 or flashinfer backend,
        currently only for llama and qwen model;
        export_fp8kv_calibration record and export kv cache quant calibration results to a json file.
-        It can be used for llama and qwen model. Calibration need to disable cudagraph and fa3 backend.
+        It can be used for llama and qwen model.
+        Calibration need to disable cudagraph and use fa3 or flashinfer backend.
+        Tp size must be no more than head num when calibration.
        ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
        ppl_fp16 mode use ppl fast fp16 decode attention kernel;
        you need to read source code to make sure the supported detail mode for all models""",

lightllm/server/api_start.py
Lines changed: 12 additions & 2 deletions

@@ -111,9 +111,19 @@ def normal_or_p_d_start(args):
        assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache"
        assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill"
    if "offline_calibration_fp8kv" in args.mode:
-        assert args.enable_fa3 is True, "offline_calibration_fp8kv mode need enable fa3"
+        assert args.enable_fa3 is True or (
+            args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
+        ), (
+            "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "--enable_flashinfer_prefill and --enable_flashinfer_decode"
+        )
    if "export_fp8kv_calibration" in args.mode:
-        assert args.enable_fa3 is True, "export_fp8kv_calibration mode need enable fa3"
+        assert args.enable_fa3 is True or (
+            args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
+        ), (
+            "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "--enable_flashinfer_prefill and --enable_flashinfer_decode"
+        )
        assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph"

    # some modes cannot yet cooperate with the advanced dynamic scheduling algorithms, to do.

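Put together, the checks above accept offline_calibration_fp8kv / export_fp8kv_calibration either with --enable_fa3 or with both flashinfer flags, and calibration additionally requires --disable_cudagraph. A small illustrative restatement of that logic (check_kv_quant_flags is a hypothetical helper, not part of lightllm):

from argparse import Namespace

def check_kv_quant_flags(args) -> None:
    flashinfer_on = args.enable_flashinfer_prefill and args.enable_flashinfer_decode
    if "offline_calibration_fp8kv" in args.mode or "export_fp8kv_calibration" in args.mode:
        assert args.enable_fa3 or flashinfer_on, "need --enable_fa3, or both flashinfer flags"
    if "export_fp8kv_calibration" in args.mode:
        assert args.disable_cudagraph, "calibration also needs --disable_cudagraph"

check_kv_quant_flags(
    Namespace(
        mode=["export_fp8kv_calibration"],
        enable_fa3=False,
        enable_flashinfer_prefill=True,
        enable_flashinfer_decode=True,
        disable_cudagraph=True,
    )
)
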
lightllm/utils/envs_utils.py
Lines changed: 15 additions & 1 deletion

@@ -139,11 +139,25 @@ def get_redundancy_expert_update_max_load_count():

 @lru_cache(maxsize=None)
 def get_kv_quant_calibration_warmup_count():
-    # The first warmup inferences after the service starts are excluded from the quantization calibration statistics; this parameter lets calibration start at different positions of a larger calibration dataset
+    # The first warmup inferences after the service starts are excluded from the quantization calibration statistics; this parameter lets calibration start at a chosen position within a larger calibration dataset
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_WARMUP_COUNT", 0))


 @lru_cache(maxsize=None)
 def get_kv_quant_calibration_inference_count():
     # Calibration statistics start after warmup; once the inference count reaches inference_count the calibration results are exported. This parameter controls how much calibration data is collected.
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000))
+
+
+g_model_init_done = False
+
+
+def get_model_init_status():
+    global g_model_init_done
+    return g_model_init_done
+
+
+def set_model_init_status(status: bool):
+    global g_model_init_done
+    g_model_init_done = status
+    return g_model_init_done

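Two environment variables control where the calibration window starts and how long it runs; they keep their existing "CALIBRARTION" spelling. A small usage sketch, assuming lightllm is installed and the variables are set before the getters are first called (they are lru_cache'd):

import os

os.environ["LIGHTLLM_KV_QUANT_CALIBRARTION_WARMUP_COUNT"] = "100"      # skip the first 100 inferences
os.environ["LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT"] = "2000"  # then collect stats for 2000 more

from lightllm.utils.envs_utils import (
    get_kv_quant_calibration_warmup_count,
    get_kv_quant_calibration_inference_count,
)

print(get_kv_quant_calibration_warmup_count())     # 100
print(get_kv_quant_calibration_inference_count())  # 2000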