|
25 | 25 | from vllm.entrypoints.utils import _validate_truncation_size |
26 | 26 | from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs |
27 | 27 | from vllm.lora.request import LoRARequest |
28 | | -from vllm.outputs import RequestOutput |
| 28 | +from vllm.outputs import CompletionOutput, RequestOutput |
29 | 29 | from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams |
30 | 30 | from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs |
31 | 31 | from vllm.usage.usage_lib import UsageContext |
|
44 | 44 | from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh |
45 | 45 |
|
46 | 46 | from forge.data.sharding import VLLMSharding |
| 47 | +from forge.data_models.completion import Completion |
| 48 | +from forge.data_models.prompt import to_prompt |
| 49 | + |
47 | 50 | from forge.interfaces import Policy as PolicyInterface |
48 | 51 | from forge.types import ProcessConfig |
49 | 52 |
|
@@ -258,7 +261,7 @@ def start_processing(self): |
258 | 261 | self._run_task = asyncio.create_task(self.run()) |
259 | 262 |
|
260 | 263 | @endpoint |
261 | | - async def generate(self, prompt: str, priority: int = 0) -> RequestOutput: |
| 264 | + async def generate(self, prompt: str, priority: int = 0) -> list[Completion]: |
262 | 265 | """Generate a response for the given prompt |
263 | 266 |
|
264 | 267 | Args: |
@@ -362,8 +365,9 @@ async def run(self): |
362 | 365 |
|
363 | 366 | for request_output in processed_outputs.request_outputs: |
364 | 367 | if request_output.finished: |
| 368 | + completions = self._to_completions(request_output) |
365 | 369 | _, fut = self.requests.pop(request_output.request_id) |
366 | | - fut.set_result(request_output) |
| 370 | + fut.set_result(completions) |
367 | 371 |
|
368 | 372 | @endpoint |
369 | 373 | async def update_weights(self, policy_version: int): |
@@ -396,6 +400,42 @@ async def get_version(self) -> int: |
396 | 400 | async def stop(self): |
397 | 401 | self.running = False |
398 | 402 |
|
| 403 | + def _to_completions(self, request_output: RequestOutput) -> list[Completion]: |
| 404 | + """Convert a RequestOutput to a list of Completion objects.""" |
| 405 | + completions = [] |
| 406 | + original_prompt = request_output.prompt |
| 407 | + prompt_token_ids = request_output.prompt_token_ids |
| 408 | + for output in request_output.outputs: |
| 409 | + completions.append( |
| 410 | + Completion( |
| 411 | + # TODO: the to_prompt encoding will be different from the original. |
| 412 | + # This is okay for now, since the prompt is not read directly from the Completion object anywhere yet. |
| 413 | + prompt=to_prompt(original_prompt), |
| 414 | + stop_reason=output.finish_reason, |
| 415 | + text=output.text, |
| 416 | + prompt_ids=torch.tensor(prompt_token_ids), |
| 417 | + token_ids=torch.tensor(output.token_ids), |
| 418 | + logprobs=self._extract_logprobs(output), |
| 419 | + ) |
| 420 | + ) |
| 421 | + |
| 422 | + return completions |
| 423 | + |
| 424 | + def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None: |
| 425 | + """ |
| 426 | + Extract log probabilities from a sample, if available. |
| 427 | + """ |
| 428 | + if one_sample.logprobs is not None: |
| 429 | + return torch.tensor( |
| 430 | + [ |
| 431 | + top_k_dict[token].logprob |
| 432 | + for token, top_k_dict in zip( |
| 433 | + one_sample.token_ids, one_sample.logprobs |
| 434 | + ) |
| 435 | + ] |
| 436 | + ) |
| 437 | + return None |
| 438 | + |
399 | 439 |
|
400 | 440 | @dataclass |
401 | 441 | class PolicyWorker(ForgeActor): |
|