From a14566536d41d6d2a0d6def5844ad737af0a7bf0 Mon Sep 17 00:00:00 2001
From: chang-wenbin
Date: Mon, 29 Sep 2025 15:31:54 +0800
Subject: [PATCH] Support MLA_CACHE & Fix V1_Schedule Bug

---
 .../engine/sched/resource_manager_v1.py |  2 +-
 fastdeploy/worker/gpu_model_runner.py   | 52 ++++++++++++++-----
 2 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index d1be17ee3d..10b46c55e2 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -350,7 +350,7 @@ def schedule(self):
                         # Prepare decoding task
                         scheduled_reqs.append(self._prepare_decode_task(request))
                         num_decoding_req_nums += 1
-                        token_budget -= 1
+                    token_budget -= 1
                 else:  # need to prefill
                     llm_logger.debug(
                         f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 4711684bb7..f508a9e84e 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1192,31 +1192,48 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
 
         logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
         cache_kvs_list = []
+
+        # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
+        # to rationalize the allocation of kvcache.
+        from fastdeploy import envs
+
+        self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
         for i in range(self.model_config.num_hidden_layers):
             key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
-            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+            if not self.mla_cache:
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
             if create_cache_tensor:
                 logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
-                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_cache_name)
-                set_data_ipc(val_cache, val_cache_name)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                    set_data_ipc(val_cache, val_cache_name)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
                 if kv_cache_quant_type == "block_wise_fp8":
                     key_cache_scales = paddle.full(
                         shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                     )
-                    val_cache_scales = paddle.full(
-                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
-                    )
-                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    if not self.mla_cache:
+                        val_cache_scales = paddle.full(
+                            shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                        )
+                        cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    else:
+                        cache_kvs_list.extend([key_cache_scales])
             else:
                 logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                val_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
-                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.empty(shape=[], dtype=cache_type)
+                    val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
+
         self.share_inputs["caches"] = cache_kvs_list
 
         if not profile and create_cache_tensor:
@@ -1936,7 +1953,18 @@ def cal_theortical_kvcache(self):
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+
+        # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
+        # to rationalize the allocation of kvcache.
+        if self.mla_cache:
+            required_memory = (
+                byte_of_dtype
+                * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim)
+                * (self.cache_config.block_size)
+                * num_layers
+            )  # compress_kv + k_pe
+        else:
+            required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
         return required_memory
 
     def not_need_stop(self) -> bool:
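
The cal_theortical_kvcache hunk reserves only kv_lora_rank + qk_rope_head_dim elements per token for MLA models (the compressed KV concatenated with the rotary key part, no separate value cache), versus 2 * hidden_dim per token for standard attention. Below is a minimal sketch of that arithmetic; the helper name and the DeepSeek-style dimensions (512 and 64) are illustrative assumptions, not values taken from this patch.

def kv_cache_bytes_per_block(
    byte_of_dtype: int,   # e.g. 2 for bfloat16
    block_size: int,      # tokens per cache block
    num_layers: int,
    hidden_dim: int,      # kv_num_heads * head_dim for standard attention
    mla_cache: bool,
    kv_lora_rank: int = 512,      # illustrative DeepSeek-style value, not from this patch
    qk_rope_head_dim: int = 64,   # illustrative DeepSeek-style value, not from this patch
) -> int:
    """Per-block KV-cache bytes, mirroring the two branches added in the patch."""
    if mla_cache:
        # MLA caches one tensor per token: compress_kv plus k_pe, no value cache.
        return byte_of_dtype * (kv_lora_rank + qk_rope_head_dim) * block_size * num_layers
    # Standard attention caches K and V separately, hence the factor of 2.
    return byte_of_dtype * 2 * block_size * hidden_dim * num_layers

# Example: 64-token blocks, 61 layers, bf16, and an illustrative 128 x 128 K/V layout.
mla_bytes = kv_cache_bytes_per_block(2, 64, 61, hidden_dim=128 * 128, mla_cache=True)
std_bytes = kv_cache_bytes_per_block(2, 64, 61, hidden_dim=128 * 128, mla_cache=False)
print(f"MLA: {mla_bytes / 2**20:.1f} MiB/block, standard: {std_bytes / 2**20:.1f} MiB/block")

This size gap is also why the allocation loop in the first hunk creates and shares only the key tensor when FD_ATTENTION_BACKEND is MLA_ATTN.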