Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,14 +641,14 @@ def launch_components(self):

self.cfg.init_cache_info()

request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.start(role, host_ip, disaggregate)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
self.engine.scheduler.start(
Expand Down
33 changes: 32 additions & 1 deletion fastdeploy/scheduler/splitwise_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def __init__(
self.writer_parallel = writer_parallel
self.writer_batch_size = writer_batch_size

self.max_model_len = kwargs.get("max_model_len", 8192)
self.enable_chunked_prefill = kwargs.get("enable_chunked_prefill", False)
self.max_num_partial_prefills = kwargs.get("max_num_partial_prefills", 1)
self.max_long_partial_prefills = kwargs.get("max_long_partial_prefills", 1)
self.long_prefill_token_threshold = kwargs.get("long_prefill_token_threshold", 0)
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

def check(self):
"""Check the configuration arguments.

This is currently a placeholder: the body is a bare ``pass``, so no
validation is performed, nothing is raised, and ``None`` is returned.
"""
pass
Expand Down Expand Up @@ -674,6 +682,7 @@ class InferScheduler:
"""

def __init__(self, config):
self.config = config
self.nodeid = config.nodeid
self.writer_parallel = config.writer_parallel
self.writer_batch_size = config.writer_batch_size
Expand Down Expand Up @@ -792,9 +801,13 @@ def get_requests(
reqs = []
required_blocks = 0
current_prefill_tokens = 0
long_partial_requests, short_partial_requests = 0, 0
cur_time = time.time()
for i in range(batch):
try:
if len(self.reqs_queue) == 0:
break

req = self.reqs_queue.popleft()
if cur_time - req.arrival_time > self.ttl:
logger.error(f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests")
Expand All @@ -803,9 +816,27 @@ def get_requests(
current_prefill_tokens += req.prompt_token_ids_len
required_input_blocks = (req.prompt_token_ids_len + block_size - 1) // block_size
required_blocks += required_input_blocks + reserved_output_blocks
if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
if required_blocks > available_blocks:
self.reqs_queue.appendleft(req)
return reqs

if self.config.enable_chunked_prefill:
if req.prompt_token_ids_len > self.config.long_prefill_token_threshold:
# long partial requests
long_partial_requests += 1
if long_partial_requests > self.config.max_long_partial_prefills:
self.reqs_queue.appendleft(req)
break
else:
short_partial_requests += 1

if short_partial_requests + long_partial_requests > self.config.max_num_partial_prefills:
self.reqs_queue.appendleft(req)
break
else:
if current_prefill_tokens > max_num_batched_tokens:
self.reqs_queue.appendleft(req)
break
# logger.info(f"Get Requests from Scheduler: {req.request_id}")
reqs.append(req)
except Exception as e:
Expand Down
Loading