
Commit d435499

[Cherry-Pick][Bug Fix]fix the bug for real size 0 in cudagraph (#3888)
* fix the bug for real size 0 in cudagraph
* fix cache_messager

Co-authored-by: Jiang-Jia-Jun <[email protected]>
1 parent c7c1627 commit d435499

File tree

2 files changed: +5 -5 lines changed


fastdeploy/cache_manager/cache_messager.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -163,7 +163,7 @@ def _prefill_layerwise_send_cache_thread(self):
         try:
             prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
             prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
-            prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
+            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
             prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
             step_shm_value = IPCSignal(
                 name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
```
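The one-line change above fixes an apparent copy-paste slip: `prefilled_layer_name` was built from the step-progress string, so the layer signal and the step signal resolved to the same shared-memory name. A minimal sketch of the collision, using plain strings with stand-in `dp_rank_id` and `gpu_id` values rather than FastDeploy's `IPCSignal` API:

```python
# Minimal sketch (plain strings, not FastDeploy's IPCSignal API) of the
# name collision fixed above; dp_rank_id and gpu_id are stand-in values.
dp_rank_id, gpu_id = 0, 0

# Before the fix: the layer-progress name reused the "..._step_" prefix,
# so the layer signal and the step signal shared one shared-memory name.
old_layer_name = f"splitwise_complete_prefilled_step_{dp_rank_id}.{gpu_id}"
step_name = f"splitwise_complete_prefilled_step_{dp_rank_id}.{gpu_id}"
assert old_layer_name == step_name  # collision: two signals, one region

# After the fix: the layer signal gets its own "..._layer_" prefix.
new_layer_name = f"splitwise_complete_prefilled_layer_{dp_rank_id}.{gpu_id}"
assert new_layer_name != step_name  # distinct names, no collision
```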

fastdeploy/worker/gpu_model_runner.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -42,6 +42,7 @@
 from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
 from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.platforms import current_platform
+from fastdeploy.utils import ceil_div

 if current_platform.is_iluvatar():
     from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
```
```diff
@@ -588,17 +589,16 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len
         """Set dummy prefill inputs to share_inputs"""
         # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
         max_dec_len = expected_decode_len + 1
-        full_length = min(
-            num_tokens // batch_size,
+        input_length = min(
+            ceil_div(num_tokens, batch_size),
             self.parallel_config.max_model_len - max_dec_len,
         )

         # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
         # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
         if self.fd_config.parallel_config.enable_expert_parallel:
-            full_length = min(full_length, 32)
+            input_length = min(input_length, 32)

-        input_length = int(full_length * self.cache_config.kv_cache_ratio)
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
```
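This hunk is the core of the fix. With floor division, `num_tokens // batch_size` truncates to 0 whenever `num_tokens < batch_size`, and the subsequent `int(full_length * kv_cache_ratio)` could zero out small values as well, producing the "real size 0" dummy input that broke CUDA graph capture. Switching to `ceil_div` and dropping the ratio scaling guarantees at least one token per dummy request. A minimal sketch of the before/after arithmetic, where `ceil_div` mirrors the helper imported from `fastdeploy.utils` and `kv_cache_ratio`, `block_size`, and `enc_dec_block_num` are stand-in values (the `min()` against `max_model_len` is omitted for brevity):

```python
# Sketch of the dummy-input length arithmetic before and after this commit.
def ceil_div(a: int, b: int) -> int:
    # Ceiling division via the standard (a + b - 1) // b identity.
    return (a + b - 1) // b

num_tokens, batch_size, kv_cache_ratio = 3, 4, 0.75  # stand-in values

# Before: floor division truncates to 0 when num_tokens < batch_size,
# and scaling by kv_cache_ratio could zero out small values too.
old_input_length = int((num_tokens // batch_size) * kv_cache_ratio)
assert old_input_length == 0  # the "real size 0" fed to CUDA graph capture

# After: ceil_div rounds up, so every dummy request gets at least one token,
# and the kv_cache_ratio scaling is dropped entirely.
new_input_length = ceil_div(num_tokens, batch_size)
assert new_input_length == 1

# block_num in the diff follows the same ceiling pattern.
block_size, enc_dec_block_num = 64, 2  # stand-in values
block_num = ceil_div(new_input_length, block_size) + enc_dec_block_num
assert block_num == 3
```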
