
Commit 110f33a

rainyfly and ltd0924 authored
[Bug fix] Test td cache messager (#3242)
* support disable cache task in decode node
* fix bugs
* Update engine.py
* Update expert_service.py
* Update splitwise_connector.py
* Optimize log for debug
* Optimize log for debug
* fix bug

Co-authored-by: ltd0924 <[email protected]>
Co-authored-by: ltd0924 <[email protected]>
1 parent a4572a5 commit 110f33a

5 files changed, +144 -57 lines changed

fastdeploy/cache_manager/cache_messager.py

Lines changed: 9 additions & 5 deletions
@@ -17,8 +17,9 @@
 import argparse
 import json
 import math
-import time
 import threading
+import time
+
 import numpy as np
 import paddle
@@ -196,7 +197,9 @@ def __init__(
 
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)
+        self.rank_id = (
+            self.rank + local_data_parallel_id * self.nranks
+        )  # align with engine worker rank (paddle.distributed.launch)
 
         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
@@ -284,7 +287,7 @@ def prefill_layerwise_send_cache_thread(self):
            if not self.cache_info:
                time.sleep(0.001)
                continue
-           logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+           logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
            for req_id, item in list(self.cache_info.items()):
                if "status" not in item:
                    continue
@@ -364,7 +367,7 @@ def prefill_layerwise_send_cache_thread(self):
 
        except Exception as e:
            logger.info(f"prefill layerwise send cache thread has exception: {e}")
-
+
    def _handle_connect_task(self):
        while True:
            try:
@@ -465,7 +468,8 @@ def main():
 if __name__ == "__main__":
 
     args = parse_args()
-    logger = get_logger("cache_messager", "cache_messager.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
 
     logger.info("create cache messager...")
     logger.info(f"{args}")

fastdeploy/engine/engine.py

Lines changed: 25 additions & 4 deletions
@@ -113,6 +113,8 @@ def __init__(self, cfg):
 
         self.start_queue_service()
 
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
+
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.resource_manager = ResourceManagerV1(
                 cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
@@ -630,11 +632,15 @@ def receiver_loop():
                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                            self.insert_tasks([task])
                        else:
+                            if not self.enable_decode_cache_task:
+                                task.error_msg = "Not enough resources"
                            new_waiting.append(task)
-
                    if new_waiting:
-                        self.waiting_requests.extend(new_waiting)
-                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                        if not self.enable_decode_cache_task:
+                            self.split_connector.send_cache_infos(new_waiting, -1)
+                        else:
+                            self.waiting_requests.extend(new_waiting)
+                            llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
                else:
                    time.sleep(0.001)
@@ -805,6 +811,22 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
 
         for task in tasks:
             start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
             if task.sampling_params.bad_words is not None:
                 task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
 
@@ -1020,7 +1042,6 @@ def _exit_sub_services(self):
             except Exception as e:
                 print(f"Error extracting sub services: {e}")
 
-
         for worker_queue in self.engine_worker_queue_server:
             worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:

fastdeploy/engine/expert_service.py

Lines changed: 28 additions & 6 deletions
@@ -26,6 +26,7 @@
 
 import numpy as np
 
+from fastdeploy.engine.request import RequestOutput
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
@@ -34,6 +35,7 @@
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
 
+
 class ExpertService:
     """
     Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -146,7 +148,7 @@ def start(
 
         # Start TokenProcessor thread
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
-
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
         self.token_processor.run()
 
         self.cfg.init_cache_info()
@@ -262,11 +264,15 @@ def receiver_loop():
                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                            self.insert_tasks([task])
                        else:
+                            if not self.enable_decode_cache_task:
+                                task.error_msg = "Not enough resources"
                            new_waiting.append(task)
-
                    if new_waiting:
-                        self.waiting_requests.extend(new_waiting)
-                        self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                        if not self.enable_decode_cache_task:
+                            self.split_connector.send_cache_infos(new_waiting, -1)
+                        else:
+                            self.waiting_requests.extend(new_waiting)
+                            self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
                else:
                    time.sleep(0.001)
@@ -310,8 +316,24 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
         if not isinstance(tasks, list):
             tasks = [tasks]
 
-        for item in tasks:
-            item.schedule_start_time = time.time()
+        for task in tasks:
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
+            task.schedule_start_time = time.time()
 
         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
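
Both engine.py and expert_service.py now fail a request the same way when the decode side reports an unallocated cache: the task is dropped from the batch and a terminal RequestOutput with error_code=500 goes back through the scheduler. A condensed sketch of that shared pattern (connector, scheduler, and logger are placeholders for the engine's real attributes; iterating over a snapshot keeps the removal safe):

from fastdeploy.engine.request import RequestOutput

def reject_unallocated(tasks, connector, scheduler, logger):
    # Iterate over a copy so tasks.remove() cannot skip the following element.
    for task in list(tasks):
        status, msg = connector.check_decode_allocated(task)
        if not status:
            logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
            scheduler.put_results(
                [RequestOutput(request_id=task.request_id, finished=True,
                               error_code=500, error_msg=msg)]
            )
            tasks.remove(task)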

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@
     "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to use PLUGINS.
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
+    # Whether to enable cache task in decode node
+    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
 }
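
Since the default is "1", decode-side cache tasks stay enabled unless the variable is exported explicitly, and because each table entry is a lambda the value is presumably read from the environment on access rather than at import time. A usage sketch (the attribute-style read matches how engine.py consumes the flag; the import path is an assumption):

import os

# Must be set before the engine reads the flag; the default "1" keeps the feature on.
os.environ["FD_ENABLE_CACHE_TASK"] = "0"

from fastdeploy import envs  # assumed import path for fastdeploy/envs.py

# engine.py gates on the raw string: enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
assert envs.FD_ENABLE_CACHE_TASK == "0"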