
Commit 6c5a5db

Author: niushengxiao (committed)

feat: change mode name to offline_calibration_fp8kv

1 parent aa6acb3 · commit 6c5a5db

File tree

4 files changed: +13 -9 lines changed

lightllm/common/mem_utils.py
lightllm/models/llama/flashattention_infer_struct.py
lightllm/models/llama/layer_infer/transformer_layer_infer.py
lightllm/server/api_cli.py


lightllm/common/mem_utils.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def select_mem_manager_class(mode):
         logger.info("Model kv cache using mode triton int8kv")
     elif "triton_fp8kv" in mode:
         raise Exception("currently only for deepseek")
-    elif "calibration_fp8kv" in mode:
+    elif "offline_calibration_fp8kv" in mode:
         memory_manager_class = CalibrationFP8KVMemoryManager
         logger.info("Model kv cache using mode calibration fp8kv")
     elif "export_fp8kv_calibration" in mode:

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
         )
         self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
-        if "calibration_fp8kv" in model.mode:
+        if "offline_calibration_fp8kv" in model.mode:
             device = input_ids.device
             self.q_scale = torch.empty(
                 (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
@@ -60,7 +60,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             )
             self.page_table[:, max_seq_len_k:].fill_(0)

-        if "calibration_fp8kv" in model.mode:
+        if "offline_calibration_fp8kv" in model.mode:
             offline_scales = self.mem_manager.scales
             head_num = self.mem_manager.head_num
             self.k_descale = (
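What this branch is setting up, roughly: offline calibration scales held by the memory manager are expanded into per-request, per-head scale/descale tensors for the FA3 kernels. The sketch below only illustrates that expansion; the shapes, dtype, and names are assumptions for illustration, not the exact lightllm layout:

import torch

# Illustrative shapes only: assume one offline calibration scale per KV head.
batch_size, head_num = 4, 8
offline_scales = torch.rand(head_num, dtype=torch.float32)  # hypothetical per-head scales

# Broadcast the per-head scales to one row per request in the batch.
k_descale = offline_scales.unsqueeze(0).expand(batch_size, head_num).contiguous()
q_scale = torch.empty((batch_size, head_num), dtype=torch.float32)

print(k_descale.shape, q_scale.shape)  # torch.Size([4, 8]) torch.Size([4, 8])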

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 4 deletions
@@ -69,15 +69,15 @@ def _bind_norm(self):
 
     def _bind_attention(self):
         if get_env_start_args().enable_fa3:
-            if "calibration_fp8kv" in self.mode:
+            if "offline_calibration_fp8kv" in self.mode:
                 self._context_attention_kernel = partial(
                     LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self
                 )
                 self._token_attention_kernel = partial(
                     LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self
                 )
                 self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self)
-            else:
+            elif not self.mode:
                 self._context_attention_kernel = partial(
                     LlamaTransformerLayerInfer._context_attention_flashattention, self
                 )
@@ -90,6 +90,8 @@ def _bind_attention(self):
                     )
                 else:
                     self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            else:
+                raise Exception(f"Unsupported mode for fa3 backend: {self.mode}")
             return
         elif get_env_start_args().enable_flashinfer_prefill:
             self._context_attention_kernel = partial(
@@ -127,7 +129,7 @@ def _bind_attention(self):
         elif "triton_int8kv" in self.mode:
             self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self)
-        elif "calibration_fp8kv" in self.mode:
+        elif "offline_calibration_fp8kv" in self.mode:
             raise Exception("calibration fp8 kvcache only support fa3 backend")
         elif "triton_flashdecoding" in self.mode:
             self._token_attention_kernel = partial(
@@ -147,14 +149,16 @@ def _bind_attention(self):
                 LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        else:
+        elif not self.mode:
             if get_env_start_args().enable_flashinfer_decode:
                 self._token_attention_kernel = partial(
                     LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self
                 )
             else:
                 self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+        else:
+            raise Exception(f"Unsupported mode: {self.mode}")
 
         return
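Every branch above uses the same binding idiom: an unbound method of LlamaTransformerLayerInfer is tied to the current instance with functools.partial and stored as the active kernel. A minimal sketch of that idiom, with hypothetical kernel names standing in for the real attention implementations:

from functools import partial

class Infer:
    def __init__(self, mode):
        self.mode = mode
        # Mirrors the dispatch structure above: one branch per recognized mode,
        # and an explicit error for anything unrecognized.
        if "offline_calibration_fp8kv" in mode:
            self._token_attention_kernel = partial(Infer._decode_fp8, self)
        elif not mode:
            self._token_attention_kernel = partial(Infer._decode_normal, self)
        else:
            raise Exception(f"Unsupported mode: {mode}")

    def _decode_fp8(self):
        return "fp8 decode path"

    def _decode_normal(self):
        return "normal decode path"

print(Infer(["offline_calibration_fp8kv"])._token_attention_kernel())  # fp8 decode path
print(Infer([])._token_attention_kernel())                             # normal decode path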

lightllm/server/api_cli.py

Lines changed: 2 additions & 2 deletions
@@ -164,13 +164,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=[],
         nargs="+",
         help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
-                | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | calibration_fp8kv
+                | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv
                 | export_fp8kv_calibration
                 triton_flashdecoding mode is for long context, current support llama llama2 qwen;
                 triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
                 triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
                 triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2;
-                calibration_fp8kv mode use float8 to store kv cache, need fa3 backend,
+                offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 backend,
                 currently only for llama and qwen model;
                 export_fp8kv_calibration record and export kv cache quant calibration results to a json file.
                 It can be used for llama and qwen model. Calibration need to disable cudagraph and fa3 backend.
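To see how the renamed value arrives from the command line, here is a self-contained argparse sketch mirroring the definition above (default=[], nargs="+"); the --mode flag name is an assumption taken from the surrounding argument, and the real registration lives in make_argument_parser:

import argparse

parser = argparse.ArgumentParser()
# Mirrors the list-valued mode argument shown above.
parser.add_argument("--mode", default=[], nargs="+", help="Model mode, e.g. offline_calibration_fp8kv")

args = parser.parse_args(["--mode", "offline_calibration_fp8kv"])
print(args.mode)                                 # ['offline_calibration_fp8kv']
print("offline_calibration_fp8kv" in args.mode)  # True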
