diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 44c62b89a3..20e908817f 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -460,8 +460,7 @@ def schedule(self):
                 # Prepare decoding task
                 scheduled_reqs.append(self._prepare_decode_task(request))
                 num_decoding_req_nums += 1
-                token_budget -= 1
-
+                token_budget -= 1
                 if (
                     request.use_extend_tables
                     and request.request_id not in self.using_extend_tables_req_id
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 4711684bb7..f508a9e84e 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1192,31 +1192,48 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
 
         logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
         cache_kvs_list = []
+
+        # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
+        # to rationalize the allocation of the KV cache.
+        from fastdeploy import envs
+
+        self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
         for i in range(self.model_config.num_hidden_layers):
             key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
-            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+            if not self.mla_cache:
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
             if create_cache_tensor:
                 logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
-                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_cache_name)
-                set_data_ipc(val_cache, val_cache_name)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                    set_data_ipc(val_cache, val_cache_name)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
                 if kv_cache_quant_type == "block_wise_fp8":
                     key_cache_scales = paddle.full(
                         shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                     )
-                    val_cache_scales = paddle.full(
-                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
-                    )
-                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    if not self.mla_cache:
+                        val_cache_scales = paddle.full(
+                            shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                        )
+                        cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    else:
+                        cache_kvs_list.extend([key_cache_scales])
             else:
                 logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                val_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
-                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.empty(shape=[], dtype=cache_type)
+                    val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
+
         self.share_inputs["caches"] = cache_kvs_list
 
         if not profile and create_cache_tensor:
@@ -1936,7 +1953,18 @@ def cal_theortical_kvcache(self):
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+
+        # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
+        # to rationalize the allocation of the KV cache.
+        if self.mla_cache:
+            required_memory = (
+                byte_of_dtype
+                * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim)
+                * (self.cache_config.block_size)
+                * num_layers
+            )  # compress_kv + k_pe
+        else:
+            required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
         return required_memory
 
     def not_need_stop(self) -> bool:
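
For readers comparing the two sizing formulas in the cal_theortical_kvcache hunk above, the standalone sketch below recomputes the per-block KV-cache footprint for both paths. It is illustrative only and not part of the patch: the helper name kv_cache_bytes_per_block and the concrete numbers in the example (bfloat16, 64-token blocks, 61 layers, kv_lora_rank=512, qk_rope_head_dim=64) are assumptions chosen to resemble a DeepSeek-style MLA model, not values taken from FastDeploy.

# Hypothetical sketch (not part of the patch): per-block KV-cache bytes for
# standard attention (k + v) vs. MLA (compress_kv + k_pe), mirroring the two
# branches added to cal_theortical_kvcache above.
def kv_cache_bytes_per_block(
    byte_of_dtype: int,
    block_size: int,
    num_layers: int,
    *,
    mla_cache: bool,
    hidden_dim: int = 0,        # kv_num_heads * head_dim (standard attention only)
    kv_lora_rank: int = 0,      # compressed-KV rank (MLA only)
    qk_rope_head_dim: int = 0,  # rotary key dimension (MLA only)
) -> int:
    if mla_cache:
        # MLA caches one compressed latent plus the rotary key part per token,
        # so there is no separate value cache.
        return byte_of_dtype * (kv_lora_rank + qk_rope_head_dim) * block_size * num_layers
    # Standard attention caches a full K and a full V per token.
    return byte_of_dtype * 2 * (block_size * hidden_dim) * num_layers


if __name__ == "__main__":
    # Assumed example numbers: bfloat16 (2 bytes), 64-token blocks, 61 layers.
    std = kv_cache_bytes_per_block(2, 64, 61, mla_cache=False, hidden_dim=128 * 128)
    mla = kv_cache_bytes_per_block(2, 64, 61, mla_cache=True, kv_lora_rank=512, qk_rope_head_dim=64)
    print(f"standard k+v: {std} bytes/block, MLA: {mla} bytes/block")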