Commit fee37b6

dcaox authored and dominicshanshan committed
[None][feat] Move StreamGeneration to scaffolding main directory (NVIDIA#8347)
Signed-off-by: Dong Cao <[email protected]>
1 parent 5e72651 commit fee37b6

7 files changed: +109 -24 lines changed


examples/scaffolding/contrib/AsyncGeneration/stream_generation_controller.py

Lines changed: 2 additions & 3 deletions
@@ -42,9 +42,8 @@ def process(self, tasks: List[Task], **kwargs):
                     "custom_sampling_params")
             elif self.custom_sampling_params:
                 task.custom_sampling_params = self.custom_sampling_params
-            stream_task = StreamGenerationTask()
-            stream_task.__dict__ = copy.deepcopy(task.__dict__)
-            stream_task.streaming_step = self.stream_step
+            stream_task = StreamGenerationTask.create_from_generation_task(
+                task, self.stream_step)
             stream_tasks.append(stream_task)
         lst = list(range(len(stream_tasks)))

tensorrt_llm/scaffolding/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,8 @@
 from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
                          get_digit_majority_vote_result)
 from .scaffolding_llm import ScaffoldingLlm
-from .task import GenerationTask, RewardTask, Task, TaskStatus
+from .task import (GenerationTask, RewardTask, StreamGenerationTask, Task,
+                   TaskStatus)
 from .task_collection import (GenerationTokenCounter, TaskCollection,
                               with_task_collection)
 from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
@@ -22,6 +23,7 @@
     "BestOfNController",
     "Task",
     "GenerationTask",
+    "StreamGenerationTask",
     "RewardTask",
     "Worker",
     "OpenaiWorker",
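With the re-export above, the streaming task type is importable from the top-level scaffolding package rather than only from the contrib module. A minimal sketch of the new import path (an assumption-laden example: it presumes the tensorrt_llm build from this commit is installed and that the dataclass can be default-constructed, since all of its fields shown in this commit carry defaults):

from tensorrt_llm.scaffolding import StreamGenerationTask  # new top-level export

# The dataclass itself now lives in tensorrt_llm/scaffolding/task.py; see its diff below.
print(StreamGenerationTask().streaming_step)  # defaults to 1, per the task.py hunk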

tensorrt_llm/scaffolding/contrib/AsyncGeneration/stream_generation.py

Lines changed: 10 additions & 0 deletions
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 from dataclasses import dataclass, field
 from typing import Any, Optional
 
@@ -22,6 +23,15 @@ class StreamGenerationTask(GenerationTask):
     # worker set this field to True when the generation is finished
     end_flag: bool = field(default=False)
 
+    @staticmethod
+    def create_from_generation_task(
+            task: GenerationTask,
+            streaming_step: int) -> "StreamGenerationTask":
+        stream_task = StreamGenerationTask()
+        stream_task.__dict__ = copy.deepcopy(task.__dict__)
+        stream_task.streaming_step = streaming_step
+        return stream_task
+
 
 async def stream_generation_handler(worker,
                                     task: StreamGenerationTask) -> TaskStatus:
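The new staticmethod (added here and, identically, to the main task.py below) wraps the deep-copy-and-assign pattern that stream_generation_controller.py previously inlined in the first hunk of this commit. A minimal usage sketch, assuming GenerationTask can be default-constructed and exposes the input_str field that the worker hunk below also uses; the prompt is purely illustrative:

from tensorrt_llm.scaffolding import GenerationTask, StreamGenerationTask

task = GenerationTask()
task.input_str = "What is 2 + 2?"  # hypothetical prompt, for illustration only

# One call replaces StreamGenerationTask() + __dict__ deepcopy + streaming_step assignment.
stream_task = StreamGenerationTask.create_from_generation_task(task, streaming_step=4)
assert stream_task.input_str == task.input_str
assert stream_task.streaming_step == 4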

tensorrt_llm/scaffolding/controller.py

Lines changed: 4 additions & 3 deletions
@@ -230,15 +230,16 @@ def process(self,
             yield ParallelProcess(generation_controllers, tasks_list,
                                   generation_kwargs_list)
 
-        candidates = [tasks[0].output_str for tasks in tasks_list]
         majority_index, majority_answer = self.majority_vote(
-            candidates, **majority_vote_kwargs)
+            tasks_list, **majority_vote_kwargs)
 
         assert isinstance(majority_answer, str), "majority_vote failed"
         # The task returned by majority vote does not have output_tokens and logits.
         tasks[0].result = tasks_list[majority_index][0].result
 
-    def majority_vote(self, candidates: List[str], **kwargs) -> Tuple[int, str]:
+    def majority_vote(self, candidates_tasks: List[List[Task]],
+                      **kwargs) -> Tuple[int, str]:
+        candidates = [tasks[0].output_str for tasks in candidates_tasks]
         return get_digit_majority_vote_result(candidates)

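Passing the full task lists instead of pre-extracted strings lets a subclass look at more than output_str when voting. A hedged sketch of a hypothetical override that keeps the digit-majority answer but returns the index of the most confident matching candidate; MajorityVoteController is the class this hunk edits (name taken from the repository, not visible in the excerpt), and output_str/cumulative_logprob are the GenerationTask properties shown in the task.py diff below. The substring match and the subclass itself are illustrative only:

from typing import List, Tuple

from tensorrt_llm.scaffolding import Task
from tensorrt_llm.scaffolding.controller import MajorityVoteController


class ConfidenceAwareMajorityVoteController(MajorityVoteController):
    """Hypothetical subclass: after the digit-majority vote, prefer the candidate
    with the highest cumulative logprob among those containing the winning answer."""

    def majority_vote(self, candidates_tasks: List[List[Task]],
                      **kwargs) -> Tuple[int, str]:
        index, answer = super().majority_vote(candidates_tasks, **kwargs)
        matching = [(i, tasks[0].cumulative_logprob or float("-inf"))
                    for i, tasks in enumerate(candidates_tasks)
                    if tasks[0].output_str and answer in tasks[0].output_str]
        if not matching:
            return index, answer  # fall back to the default choice
        best_index = max(matching, key=lambda pair: pair[1])[0]
        return best_index, answer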
tensorrt_llm/scaffolding/scaffolding_llm.py

Lines changed: 15 additions & 9 deletions
@@ -175,13 +175,19 @@ def generate_async(self, prompt: str) -> ScaffoldingResult:
         result = ScaffoldingResult(self.streaming_event)
 
         async def put_request():
-            request = ScaffoldingRequest(
-                prompt=prompt,
-                kwargs={},
-                result=result,
-                controller=self.prototype_controller.clone())
-
-            await self.task_queue.put(request)
+            try:
+                request = ScaffoldingRequest(
+                    prompt=prompt,
+                    kwargs={},
+                    result=result,
+                    controller=self.prototype_controller.clone())
+            except Exception as e:
+                self.task_queue.put(None)
+                print(
+                    f"Error: build ScaffoldingRequest failed: {e} \n {traceback.format_exc()}"
+                )
+            else:
+                await self.task_queue.put(request)
 
         asyncio.run_coroutine_threadsafe(put_request(), self.loop)
 
@@ -208,7 +214,7 @@ def enable_output_task_collection(self):
 
     def shutdown(self, shutdown_workers=False):
 
-        def shutdown_workers():
+        def shutdown_workers_func():
             for worker in self.workers.values():
                 worker.shutdown()
 
@@ -228,4 +234,4 @@ async def stop_task_on_loop():
         self.shutdown_event.set()
 
         if shutdown_workers:
-            shutdown_workers()
+            shutdown_workers_func()
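The try/except above adopts a sentinel-on-failure pattern: if the request cannot be built, a None is still pushed onto the queue so the consumer loop is not left waiting, and the traceback is printed. A self-contained sketch of the same pattern outside the scaffolding internals (all names here are illustrative, not from the repository):

import asyncio
import traceback


async def producer(queue: asyncio.Queue, build_request):
    try:
        request = build_request()
    except Exception as e:
        # Push a sentinel so the consumer can observe the failure instead of blocking forever.
        await queue.put(None)
        print(f"Error: build request failed: {e}\n{traceback.format_exc()}")
    else:
        await queue.put(request)


async def consumer(queue: asyncio.Queue):
    request = await queue.get()
    if request is None:  # sentinel: upstream failed to build the request
        return
    print(f"handling {request!r}")


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    await asyncio.gather(producer(queue, lambda: {"prompt": "hi"}), consumer(queue))


asyncio.run(main())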

tensorrt_llm/scaffolding/task.py

Lines changed: 31 additions & 3 deletions
@@ -1,10 +1,11 @@
1+
import copy
12
from dataclasses import dataclass, field
23
from enum import Enum
3-
from typing import Dict, List, Optional, Union
4+
from typing import Any, Dict, List, Optional, Union
45

56
import torch
67

7-
from tensorrt_llm.executor.result import GenerationResult
8+
from tensorrt_llm.executor.result import GenerationResult, TokenLogprobs
89

910

1011
@dataclass
@@ -64,6 +65,7 @@ class GenerationTask(Task):
6465
# result field
6566
# link to TRTLLM's GenerationResult, for async update in streaming mode
6667
_result: Optional[GenerationResult] = None
68+
customized_result_fields: Dict[str, Any] = field(default_factory=dict)
6769

6870
@property
6971
def result(self) -> GenerationResult:
@@ -96,7 +98,7 @@ def cumulative_logprob(self) -> Optional[float]:
9698
0].cumulative_logprob if self._result else None
9799

98100
@property
99-
def logprobs(self) -> Optional[List[float]]:
101+
def logprobs(self) -> Optional[TokenLogprobs]:
100102
return self._result.outputs[0].logprobs if self._result else None
101103

102104
@property
@@ -115,6 +117,32 @@ def create_scaffolding_output(self) -> GenerationResult:
115117
return self._result
116118

117119

120+
@dataclass
121+
class StreamGenerationTask(GenerationTask):
122+
# input field
123+
# if the flag is set to True, the worker will cancel the generation work
124+
cancel_flag: Optional[bool] = field(default=False)
125+
# the task will be returned to the controller with at least new streaming_step tokens
126+
# if the streaming_step is set to 0,
127+
# the task will be returned to the controller immediately with
128+
# new tokens that have already been generated.
129+
streaming_step: Optional[int] = field(default=1)
130+
131+
#result field
132+
# worker set this field and identify the same task by this field
133+
request_handle: Any = field(default=None)
134+
# worker set this field to True when the generation is finished
135+
end_flag: bool = field(default=False)
136+
137+
@staticmethod
138+
def create_from_generation_task(task: GenerationTask,
139+
streaming_step) -> "StreamGenerationTask":
140+
stream_task = StreamGenerationTask()
141+
stream_task.__dict__ = copy.deepcopy(task.__dict__)
142+
stream_task.streaming_step = streaming_step
143+
return stream_task
144+
145+
118146
@dataclass
119147
class RewardTask(Task):
120148
# input field
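The defaults declared in the new dataclass determine how a controller and worker hand a streaming task back and forth. A quick sanity-check sketch of those defaults (assuming the package from this commit is installed and that the dataclass can be default-constructed):

from tensorrt_llm.scaffolding import StreamGenerationTask

stream_task = StreamGenerationTask()

# Defaults as declared above.
assert stream_task.streaming_step == 1      # hand back to the controller after >= 1 new token
assert stream_task.cancel_flag is False     # controller sets True to ask the worker to abort
assert stream_task.end_flag is False        # worker sets True once generation finishes
assert stream_task.request_handle is None   # worker stores its async generation handle here
assert stream_task.customized_result_fields == {}  # new field inherited from GenerationTask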

tensorrt_llm/scaffolding/worker.py

Lines changed: 44 additions & 5 deletions
@@ -1,15 +1,16 @@
1+
import asyncio
12
from abc import ABC
2-
from typing import Callable
3+
from typing import Callable, Optional
34

45
import openai
56
from transformers import AutoTokenizer
67

78
from tensorrt_llm import LLM
89
from tensorrt_llm.executor import GenerationExecutor
9-
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
10+
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig
1011
from tensorrt_llm.sampling_params import SamplingParams
1112

12-
from .task import GenerationTask, Task, TaskStatus
13+
from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus
1314

1415
ExecutorCls = GenerationExecutor
1516

@@ -150,6 +151,7 @@ def init_with_new_llm(
150151
max_num_tokens: int = 4096,
151152
kv_cache_free_gpu_memory_fraction: float = 0.9,
152153
disable_overlap_scheduler: bool = False,
154+
scheduler_config: Optional[SchedulerConfig] = None,
153155
):
154156
kv_cache_config = KvCacheConfig(
155157
free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, )
@@ -168,7 +170,8 @@ def init_with_new_llm(
168170
disable_overlap_scheduler=disable_overlap_scheduler,
169171
kv_cache_config=kv_cache_config,
170172
max_batch_size=max_batch_size,
171-
max_num_tokens=max_num_tokens)
173+
max_num_tokens=max_num_tokens,
174+
scheduler_config=scheduler_config)
172175

173176
worker = cls(llm, tokenizer)
174177
worker.own_llm = True
@@ -201,8 +204,44 @@ async def generation_handler(self, task: GenerationTask) -> TaskStatus:
201204
# TODO: error handle
202205
return TaskStatus.SUCCESS
203206

207+
async def stream_generation_handler(
208+
self, task: StreamGenerationTask) -> TaskStatus:
209+
210+
async def get_step_or_more_tokens(task: StreamGenerationTask):
211+
if task.cancel_flag:
212+
task.end_flag = True
213+
task.request_handle.abort()
214+
return TaskStatus.SUCCESS
215+
216+
for _ in range(task.streaming_step):
217+
await task.request_handle._aresult_step()
218+
if task.request_handle._done:
219+
break
220+
221+
while not task.request_handle._done:
222+
async_task = asyncio.create_task(
223+
task.request_handle._aresult_step())
224+
if not async_task.done():
225+
async_task.cancel()
226+
break
227+
228+
if task.request_handle._done:
229+
task.end_flag = True
230+
231+
if getattr(task, 'end_flag', False):
232+
return TaskStatus.SUCCESS
233+
if task.request_handle is None:
234+
sampling_params = self.convert_task_params(task)
235+
task.request_handle = self.llm.generate_async(
236+
task.input_str, sampling_params=sampling_params, streaming=True)
237+
task._result = task.request_handle
238+
await get_step_or_more_tokens(task)
239+
204240
def shutdown(self):
205241
if self.own_llm:
206242
self.llm.shutdown()
207243

208-
task_handlers = {GenerationTask: generation_handler}
244+
task_handlers = {
245+
GenerationTask: generation_handler,
246+
StreamGenerationTask: stream_generation_handler
247+
}
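On the controller side, the same StreamGenerationTask is yielded to the worker repeatedly until end_flag is set, so partial output becomes visible after every streaming_step tokens, and cancel_flag can stop the request early. A hedged sketch of such a polling controller; Controller, output_str, and result come from the surrounding scaffolding code (Controller is assumed to be exported from tensorrt_llm.scaffolding), while the subclass name and the length-based early stop are purely illustrative:

from typing import List

from tensorrt_llm.scaffolding import Controller, StreamGenerationTask, Task


class EarlyStopStreamController(Controller):
    """Hypothetical controller: stream one task and cancel once the output exceeds a budget."""

    def process(self, tasks: List[Task], **kwargs):
        stream_task = StreamGenerationTask.create_from_generation_task(
            tasks[0], streaming_step=8)
        while not stream_task.end_flag:
            # Each yield routes the task to TRTLLMWorker.stream_generation_handler above,
            # which returns it after at least `streaming_step` new tokens (or on completion).
            yield [stream_task]
            if stream_task.output_str and len(stream_task.output_str) > 2000:
                stream_task.cancel_flag = True  # the worker aborts the request on the next step
                yield [stream_task]
        tasks[0].result = stream_task.result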
