
Commit 9020f59

Author: niushengxiao · committed
feat: add per tensor fp8 kv quant for flashinfer
1 parent 79b8280 · commit 9020f59

11 files changed: +1522, -26 lines

lightllm/common/offline_fp8_quant_mem_manager.py

Lines changed: 38 additions & 18 deletions
@@ -30,25 +30,33 @@ def __init__(
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
         self.count = 0
         self.scales = None
+        self.scales_list = []
         self.abs_max = None

         if is_export_mode:
-            self.abs_max = torch.zeros((layer_num, 2 * head_num), dtype=torch.float32, device="cuda")
+            scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
+            self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda")
         elif get_env_start_args().kv_quant_calibration_config_path is not None:
             logger.info(
                 f"kv_quant_calibration_config_path {get_env_start_args().kv_quant_calibration_config_path} is set, "
                 "will load kv quant calibration config"
             )
             cfg = self._load_and_check_config()

-            self.scales = torch.tensor(cfg["scales"], dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
-            if dist.is_initialized() and dist.get_world_size() > 1:
+            self.scales_list = cfg["scales"]
+            self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
+            if not get_env_start_args().enable_fa3:
+                self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
+            if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
                 end_head = start_head + head_num
                 k_scales = self.scales[:, start_head:end_head].contiguous()
                 v_scales = self.scales[:, start_head + half_head : end_head + half_head].contiguous()
-                self.scales = torch.cat((k_scales, v_scales), dim=-1)
+                current_scales = torch.cat((k_scales, v_scales), dim=-1)
+
+                self.scales_list = current_scales.tolist()
+                self.scales = current_scales
         else:
             logger.warning("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales")

@@ -74,8 +82,12 @@ def _load_and_check_config(self):
                 raise ValueError(
                     f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
                 )
-            if cfg["quant_type"] != "per_head":
-                raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
+            if get_env_start_args().enable_fa3:
+                if cfg["quant_type"] != "per_head":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
+            else:
+                if cfg["quant_type"] != "per_tensor":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match flashinfer backend")

             return cfg
         else:
@@ -93,21 +105,29 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
             logger.info("kv cache calibration mode will collect kv cache data for quantization calibration")

         if self.abs_max is not None and self.count >= warmup_counts:
-            kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
+            if get_env_start_args().enable_fa3:
+                kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
+            else:
+                k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
+                v_max = kv_buffer[:, self.head_num :, :].abs().amax(dim=()).to(torch.float32)
+                kv_max = torch.tensor([k_max, v_max], device="cuda", dtype=torch.float32)
             self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max)
             if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1:
                 final_abs_max = self.abs_max
                 if dist.is_initialized() and dist.get_world_size() > 1:
-                    k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
-                    k_max = k_max.contiguous()
-                    v_max = v_max.contiguous()
-                    gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
-                    gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
-                    dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
-                    dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
-                    k_max = torch.cat(gathered_k_max, dim=-1)
-                    v_max = torch.cat(gathered_v_max, dim=-1)
-                    final_abs_max = torch.cat((k_max, v_max), dim=-1)
+                    if get_env_start_args().enable_fa3:
+                        k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
+                        k_max = k_max.contiguous()
+                        v_max = v_max.contiguous()
+                        gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
+                        gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
+                        dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
+                        dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
+                        k_max = torch.cat(gathered_k_max, dim=-1)
+                        v_max = torch.cat(gathered_v_max, dim=-1)
+                        final_abs_max = torch.cat((k_max, v_max), dim=-1)
+                    else:
+                        dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False)

                 self.scales = final_abs_max / self.qmax
                 self.scales = torch.where(self.scales > 0, self.scales, torch.ones_like(self.scales))
@@ -124,7 +144,7 @@ def _export_calibration_data(self):
         cfg = {
             "version": "1.0",
             "architectures": model_arch,
-            "quant_type": "per_head",
+            "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
            "qmin": self.qmin,
            "qmax": self.qmax,
            "num_layers": self.layer_num,

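The backend difference above is one of calibration granularity: the fa3 path keeps per-head scales of shape [layer_num, 2 * head_num], while the flashinfer path collapses them to a single K scale and a single V scale per layer, shape [layer_num, 2], which is what the new "per_tensor" quant type records. A minimal sketch of that per-tensor math, using hypothetical helper names that are not part of this commit:

import torch

# Per-tensor calibration sketch (illustrative only).
# abs_max has shape [layer_num, 2]: column 0 holds the K abs-max, column 1 the V abs-max.
def per_tensor_kv_scales(abs_max: torch.Tensor, qmax: float) -> torch.Tensor:
    scales = abs_max / qmax
    # guard against layers that never saw calibration data, mirroring the torch.where(...) above
    return torch.where(scales > 0, scales, torch.ones_like(scales))

def quantize_kv_layer(kv_buffer: torch.Tensor, head_num: int, k_scale: float, v_scale: float):
    # kv_buffer: [token_num, 2 * head_num, head_dim], K heads first, then V heads
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    k = (kv_buffer[:, :head_num, :] / k_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    v = (kv_buffer[:, head_num:, :] / v_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return k, v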
lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 56 additions & 2 deletions
@@ -129,6 +129,15 @@ def _bind_attention(self):
         elif "triton_int8kv" in self.mode:
             self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self)
+        elif "offline_calibration_fp8kv" in self.mode:
+            assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self)
+            self._context_attention_kernel = partial(
+                LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self
+            )
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self
+            )
         elif "triton_flashdecoding" in self.mode:
             self._token_attention_kernel = partial(
                 LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self
@@ -147,14 +156,19 @@ def _bind_attention(self):
                 LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        elif not self.mode:
+        elif "export_fp8kv_calibration" in self.mode or not self.mode:
             if get_env_start_args().enable_flashinfer_decode:
                 self._token_attention_kernel = partial(
                     LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self
                 )
             else:
                 self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
-            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            if "export_fp8kv_calibration" in self.mode:
+                self._copy_kv_to_mem_cache = partial(
+                    LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self
+                )
+            else:
+                self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         else:
             raise Exception(f"Unsupported mode: {self.mode}")

@@ -214,6 +228,26 @@ def _tpsp_get_qkv(
         )
         return q, cache_kv

+    def _context_attention_flashinfer_kernel_fp8(
+        self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        kv = kv.unsqueeze(1)
+        k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn)
+        v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn)
+        offline_scales = infer_state.mem_manager.scales_list
+        k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None
+        v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None
+        infer_state.prefill_wrapper.run(
+            q.view(q.shape[0], -1, self.head_dim_),
+            (k, v),
+            k_scale=k_descale,
+            v_scale=v_descale,
+            out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+        )
+        return o_tensor
+
     def _context_attention_flashinfer_kernel(
         self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -474,6 +508,26 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager):
         )
         return

+    def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1)
+        k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn)
+        v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn)
+        offline_scales = infer_state.mem_manager.scales_list
+        k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None
+        v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None
+        infer_state.decode_wrapper.run(
+            q.view(calcu_shape1),
+            (k, v),
+            k_scale=k_descale,
+            v_scale=v_descale,
+            out=o_tensor.view(calcu_shape1),
+        )
+        return o_tensor
+
     def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None):
         batch_size = infer_state.batch_size
         calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
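Both new fp8 attention paths reinterpret the per-layer kv cache as torch.float8_e4m3fn and hand flashinfer the per-layer scales as plain Python floats taken from mem_manager.scales_list ([layer][0] for K, [layer][1] for V), which matches the scalar k_scale / v_scale arguments of the wrapper run() calls above. Conceptually the scales act as dequantization factors; a rough, non-fused reference of the same semantics, with a hypothetical helper that is not part of this commit:

import torch

# Reference-only dequantization of one layer's fp8 kv cache (illustrative sketch).
def dequant_kv_layer(kv_fp8: torch.Tensor, head_num: int, k_scale: float, v_scale: float,
                     dtype: torch.dtype = torch.bfloat16):
    # kv_fp8: [token_num, 2 * head_num, head_dim] viewed as float8_e4m3fn
    k = kv_fp8[:, :head_num, :].to(dtype) * k_scale
    v = kv_fp8[:, head_num:, :].to(dtype) * v_scale
    return k, v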

lightllm/models/llama/model.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(self, model):
             ),
         ]
         self.q_data_type = model.data_type
-        self.kv_data_type = model.data_type
+        self.kv_data_type = torch.float8_e4m3fn if "offline_calibration_fp8kv" in model.mode else model.data_type


 @ModelRegistry("llama")

lightllm/server/api_cli.py

Lines changed: 3 additions & 2 deletions
@@ -170,10 +170,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
     triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
     triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
     triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2;
-    offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 backend,
+    offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 or flashinfer backend,
     currently only for llama and qwen model;
     export_fp8kv_calibration record and export kv cache quant calibration results to a json file.
-    It can be used for llama and qwen model. Calibration need to disable cudagraph and fa3 backend.
+    It can be used for llama and qwen model.
+    Calibration need to disable cudagraph and use fa3 or flashinfer backend.
     ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
     ppl_fp16 mode use ppl fast fp16 decode attention kernel;
     you need to read source code to make sure the supported detail mode for all models""",

lightllm/server/api_start.py

Lines changed: 12 additions & 2 deletions
@@ -111,9 +111,19 @@ def normal_or_p_d_start(args):
         assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache"
         assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill"
     if "offline_calibration_fp8kv" in args.mode:
-        assert args.enable_fa3 is True, "offline_calibration_fp8kv mode need enable fa3"
+        assert args.enable_fa3 is True or (
+            args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
+        ), (
+            "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "--enable_flashinfer_prefill and --enable_flashinfer_decode"
+        )
     if "export_fp8kv_calibration" in args.mode:
-        assert args.enable_fa3 is True, "export_fp8kv_calibration mode need enable fa3"
+        assert args.enable_fa3 is True or (
+            args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
+        ), (
+            "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "--enable_flashinfer_prefill and --enable_flashinfer_decode"
+        )
         assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph"

     # Some modes cannot yet work together with the advanced dynamic scheduling algorithms; to do.

lightllm/utils/envs_utils.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def get_redundancy_expert_update_max_load_count():

 @lru_cache(maxsize=None)
 def get_kv_quant_calibration_warmup_count():
-    # The first warmup inference passes after server startup are not counted in the quantization calibration statistics; this parameter allows calibration to start at different positions in a larger calibration dataset
+    # The first warmup inference passes after server startup are not counted in the quantization calibration statistics; this parameter allows calibration to start at a different position within a larger calibration dataset
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_WARMUP_COUNT", 0))

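The warmup count lets calibration skip a number of inference passes after startup, for example to begin collecting statistics at a later point in a larger calibration set. A hedged usage example (the value 32 is arbitrary, chosen only for illustration):

import os

# Skip the first 32 inference passes before kv quant calibration statistics are collected.
# Must be set in the environment of the server process before it starts.
os.environ["LIGHTLLM_KV_QUANT_CALIBRARTION_WARMUP_COUNT"] = "32"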
