@@ -1,15 +1,16 @@
+import asyncio
 from abc import ABC
-from typing import Callable
+from typing import Callable, Optional
 
 import openai
 from transformers import AutoTokenizer
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutor
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
-from .task import GenerationTask, Task, TaskStatus
+from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus
 
 ExecutorCls = GenerationExecutor
 
@@ -150,6 +151,7 @@ def init_with_new_llm(
         max_num_tokens: int = 4096,
         kv_cache_free_gpu_memory_fraction: float = 0.9,
         disable_overlap_scheduler: bool = False,
+        scheduler_config: Optional[SchedulerConfig] = None,
     ):
         kv_cache_config = KvCacheConfig(
             free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, )
@@ -168,7 +170,8 @@ def init_with_new_llm(
             disable_overlap_scheduler=disable_overlap_scheduler,
             kv_cache_config=kv_cache_config,
             max_batch_size=max_batch_size,
-            max_num_tokens=max_num_tokens)
+            max_num_tokens=max_num_tokens,
+            scheduler_config=scheduler_config)
 
         worker = cls(llm, tokenizer)
         worker.own_llm = True
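
For context, here is a minimal caller sketch of the new knob: build a `SchedulerConfig` and forward it through `init_with_new_llm`, which now plumbs it into the `LLM` constructor. The class name `TRTLLMWorker`, the positional model-path argument, and the `CapacitySchedulerPolicy` import are assumptions not shown in this diff; only the `scheduler_config` parameter itself is new here.

```python
# Sketch only: TRTLLMWorker, the model path argument, and
# CapacitySchedulerPolicy are assumed, not part of this diff.
from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                          SchedulerConfig)

scheduler_config = SchedulerConfig(
    capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT)

worker = TRTLLMWorker.init_with_new_llm(
    "path/to/model",                    # placeholder model path
    max_num_tokens=4096,
    scheduler_config=scheduler_config,  # new: forwarded to the LLM ctor
)
```
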
@@ -201,8 +204,51 @@ async def generation_handler(self, task: GenerationTask) -> TaskStatus:
         # TODO: error handle
         return TaskStatus.SUCCESS
 
+    async def stream_generation_handler(
+            self, task: StreamGenerationTask) -> TaskStatus:
+
+        async def get_step_or_more_tokens(task: StreamGenerationTask):
+            # A cancelled task aborts the in-flight request and finishes.
+            if task.cancel_flag:
+                task.end_flag = True
+                task.request_handle.abort()
+                return TaskStatus.SUCCESS
+
+            # Wait for at least `streaming_step` new tokens, or the end
+            # of the request, whichever comes first.
+            for _ in range(task.streaming_step):
+                await task.request_handle._aresult_step()
+                if task.request_handle._done:
+                    break
+
+            # Drain tokens that are already buffered without blocking:
+            # yield control once so the step task can run, then cancel
+            # it if no further token was immediately available.
+            while not task.request_handle._done:
+                async_task = asyncio.create_task(
+                    task.request_handle._aresult_step())
+                await asyncio.sleep(0)
+                if not async_task.done():
+                    async_task.cancel()
+                    break
+
+            if task.request_handle._done:
+                task.end_flag = True
+            return TaskStatus.SUCCESS
+
+        if getattr(task, 'end_flag', False):
+            return TaskStatus.SUCCESS
+        if task.request_handle is None:
+            sampling_params = self.convert_task_params(task)
+            task.request_handle = self.llm.generate_async(
+                task.input_str, sampling_params=sampling_params, streaming=True)
+            task._result = task.request_handle
+        return await get_step_or_more_tokens(task)
+
     def shutdown(self):
         if self.own_llm:
             self.llm.shutdown()
 
-    task_handlers = {GenerationTask: generation_handler}
+    task_handlers = {
+        GenerationTask: generation_handler,
+        StreamGenerationTask: stream_generation_handler
+    }
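
A hedged driving-loop sketch for the new handler: the caller re-submits the same `StreamGenerationTask` until the handler sets `end_flag`, receiving at least `streaming_step` new tokens per call, and may set `cancel_flag` to abort the underlying request. `run_task`, `create_from_prompt`, and `output_str` are assumed to exist on the scaffolding `Worker`/task classes as they do for plain `GenerationTask`; treat the names as illustrative.

```python
import asyncio

async def stream_demo(worker):
    # Assumed helpers: create_from_prompt, run_task, output_str.
    task = StreamGenerationTask.create_from_prompt("Write a haiku about GPUs.")
    task.streaming_step = 4              # wait for >= 4 new tokens per call

    while not getattr(task, 'end_flag', False):
        await worker.run_task(task)      # dispatches via task_handlers above
        print(task.output_str)           # partial text grows between calls
        # task.cancel_flag = True        # would abort on the next call

# asyncio.run(stream_demo(worker))
```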