Skip to content

Commit b433a93

Browse files
authored
Fix the prefilled_step_idx signal bug in cache_messager under CUDAGraph and PD (#4235)
1 parent 870364b commit b433a93

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

fastdeploy/cache_manager/cache_messager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def prefill_layerwise_send_cache_thread(self):
267267
self.cache_info[info["request_id"]] = info
268268
prefilled_layer_idx = layer_shm_value.value[0]
269269
prefilled_step_idx = step_shm_value.value[0]
270-
logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
271270
if prefilled_layer_idx == self.num_layers - 1:
272271
time.sleep(0.001)
273272
prefilled_layer_idx = layer_shm_value.value[0]

fastdeploy/worker/worker_process.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,23 @@ def initialize_kv_cache(self) -> None:
442442

443443
def graph_optimize_and_warm_up_model(self) -> None:
444444
self.worker.graph_optimize_and_warm_up_model()
445+
# Reset the cache_messager prefilled_step signal to -1 after warm-up (prefill role only)
446+
if self.scheduler_config.splitwise_role == "prefill":
447+
dp_rank_id = (
448+
self.local_rank
449+
+ self.parallel_config.local_data_parallel_id * self.parallel_config.tensor_parallel_size
450+
)
451+
gpu_id = self.worker.model_runner.device_id
452+
prefilled_step_name = f"splitwise_complete_prefilled_step_{dp_rank_id}"
453+
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
454+
step_shm_value = IPCSignal(
455+
name=prefilled_step_name,
456+
array=prefilled_step_idx_data,
457+
dtype=np.int32,
458+
suffix=gpu_id,
459+
create=False,
460+
)
461+
step_shm_value.value[0] = -1
445462

446463
def init_device(self) -> None:
447464
"""Initialize device and Construct model runner"""
@@ -842,7 +859,7 @@ def run_worker_proc() -> None:
842859
worker_proc.initialize_kv_cache()
843860

844861
# Trigger CUDAGraph capture
845-
worker_proc.worker.graph_optimize_and_warm_up_model()
862+
worker_proc.graph_optimize_and_warm_up_model()
846863

847864
# Initialize health status
848865
worker_proc.init_health_status()

0 commit comments

Comments
 (0)