
Commit c1ebd6d

feat(vllm): asynchronous call support
1 parent 0db3171 commit c1ebd6d
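
The commit switches the backend from the synchronous vllm.LLM entry point to vllm.AsyncLLMEngine so generation can be awaited and streamed. As a point of reference, here is a minimal sketch of the vLLM async streaming pattern the backend now relies on, outside of mellea; the model id and prompt are placeholders and the engine arguments are reduced to the bare minimum, so treat it as illustrative rather than as the backend's actual code.

import asyncio

import vllm


async def main() -> None:
    # Build the async engine the same way the backend now does.
    engine = vllm.AsyncLLMEngine.from_engine_args(
        vllm.AsyncEngineArgs(model="<hf-model-id>")  # placeholder model id
    )
    params = vllm.SamplingParams(
        # DELTA streams incremental chunks; FINAL_ONLY returns one finished result.
        output_kind=vllm.sampling_params.RequestOutputKind.DELTA
    )
    text = ""
    # generate() yields vllm.RequestOutput chunks as an async generator.
    async for chunk in engine.generate(
        request_id="example-0", prompt="Hello", sampling_params=params
    ):
        text += chunk.outputs[0].text  # each DELTA chunk carries only the new text
    print(text)


asyncio.run(main())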

File tree

1 file changed: +115 -44 lines changed

mellea/backends/vllm.py

Lines changed: 115 additions & 44 deletions
@@ -6,8 +6,11 @@
 from __future__ import annotations
 
 import abc
+import asyncio
 import dataclasses
 import datetime
+import functools
+import importlib
 import inspect
 import json
 import os
@@ -20,7 +23,7 @@
 import outlines_core
 import torch
 import vllm  # type:ignore
-from transformers import PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from mellea.backends import BaseModelSubclass
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
@@ -32,12 +35,14 @@
 )
 from mellea.backends.types import ModelOption
 from mellea.backends.utils import extract_model_tool_requests, to_chat
+from mellea.helpers.async_helpers import send_to_queue
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
     Component,
     Context,
     GenerateLog,
+    GenerateType,
     ModelOutputThunk,
     TemplateRepresentation,
 )
@@ -114,7 +119,7 @@ def __init__(
         # vllm requires some model options during instantiation.
         engine_args = self._simplify_and_merge(model_options)
         engine_args = self._make_backend_specific_and_remove(
-            engine_args, vllm.EngineArgs
+            engine_args, vllm.AsyncEngineArgs
         )
 
         logger = FancyLogger.get_logger()
@@ -135,8 +140,8 @@ def __init__(
         while True:
             retry += 1
             try:
-                self._model = vllm.LLM(
-                    model=self._hf_model_id, task="generate", **engine_args
+                self._model = vllm.AsyncLLMEngine.from_engine_args(
+                    vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args)
                 )
                 break
             except torch._dynamo.exc.BackendCompilerFailed as e:
@@ -187,12 +192,18 @@ def __init__(
             f"max_num_seqs: {engine_args['max_num_seqs']}\n"
         )
 
-        self._tokenizer: PreTrainedTokenizerBase = self._model.get_tokenizer()  # type:ignore
+        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+            self._hf_model_id
+        )  # type:ignore
 
-        # see notes in outlines.models.vllm.adapt_tokenizer
-        self._tokenizer_for_outlines: PreTrainedTokenizerBase = outlines.models.VLLM(
-            self._model
-        ).tokenizer  # type:ignore
+        # See the notes in outlines.models.vllm.adapt_tokenizer for why this is needed.
+        # Note: there is a module named outlines.models.vllm and a function named outlines.models.vllm.vllm.
+        # However, outlines.models imports outlines.models.vllm.vllm as vllm,
+        # thus the module outlines.models.vllm becomes inaccessible,
+        # hence the use of importlib to get the module.
+        self._tokenizer_for_outlines: PreTrainedTokenizerBase = importlib.import_module(
+            "outlines.models.vllm"
+        ).adapt_tokenizer(self._tokenizer)
 
     def generate_from_context(
         self,
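
The importlib indirection in the hunk above works around the name shadowing described in the commit's comment. A small sketch of the issue, assuming the outlines package layout that the comment describes (attribute access on outlines.models resolves "vllm" to the re-exported function, while importlib returns the real submodule that provides adapt_tokenizer):

import importlib

import outlines.models

# Attribute access goes through the package namespace, where the submodule name
# is assumed (per the comment) to be shadowed by the re-exported function
# outlines.models.vllm.vllm.
shadowed = outlines.models.vllm

# importlib resolves the dotted path as a module, bypassing the shadowing,
# so adapt_tokenizer is reachable here.
vllm_module = importlib.import_module("outlines.models.vllm")
print(callable(shadowed), hasattr(vllm_module, "adapt_tokenizer"))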
@@ -232,8 +243,6 @@ def _generate_from_context_standard(
         # Construct input.
         # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
         # Otherwise, we will linearize the context and treat it as a raw input.
-        decoded_result: str | None = None
-
         if ctx.is_chat_context:
             system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
             ctx_as_chat = to_chat(action, ctx, self.formatter, system_prompt)
@@ -265,7 +274,8 @@
             sampling_params = vllm.SamplingParams(
                 **self._make_backend_specific_and_remove(
                     model_options, vllm.SamplingParams
-                )
+                ),
+                output_kind=vllm.sampling_params.RequestOutputKind.DELTA,  # returns results incrementally
             )
 
             if format is not None:
@@ -287,44 +297,95 @@
                 [logits_processor] if logits_processor is not None else []
             )
 
-            ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
-                [input_str], sampling_params=sampling_params
+            # stream = model_options.get(ModelOption.STREAM, False)
+            # if stream:
+
+            output = ModelOutputThunk(None)
+
+            generator = self._model.generate(  # type: ignore
+                request_id=str(id(output)),
+                prompt=input_str,
+                sampling_params=sampling_params,
             )  # type: ignore
 
-            decoded_results = [ro.outputs[0].text for ro in ros]
+            output._context = ctx.render_for_generation()
+            output._action = action
+            output._model_options = model_options
+
+            output._process = self.processing
+            output._post_process = functools.partial(
+                self.post_processing,
+                conversation=ctx_as_chat,
+                tool_calls=tool_calls,
+                tools=tools,
+                seed=model_options.get(ModelOption.SEED, None),
+            )
+
+            try:
+                # This function should always be called from a running event loop so we don't have to worry about
+                # scheduling the task to a specific event loop here.
+                output._generate = asyncio.create_task(
+                    send_to_queue(generator, output._async_queue)  # type: ignore
+                )
+                output._generate_type = GenerateType.ASYNC
+            except RuntimeError as e:
+                # Most likely cause is running this function without an event loop present.
+                raise e
 
-            decoded_result = decoded_results[0]
+            return output
 
         else:
             raise Exception("Does not yet support non-chat contexts.")
 
-        assert decoded_result is not None
+    async def processing(self, mot: ModelOutputThunk, chunk: vllm.RequestOutput):
+        """Process the returned chunks or the complete response."""
+        if mot._underlying_value is None:
+            mot._underlying_value = ""
+        mot._underlying_value += chunk.outputs[0].text
+
+    async def post_processing(
+        self,
+        mot: ModelOutputThunk,
+        conversation: list[dict],
+        tool_calls: bool,
+        tools: dict[str, Callable],
+        seed,
+    ):
+        """Called when generation is done."""
 
-        result = ModelOutputThunk(value=decoded_result)
+        # The ModelOutputThunk must be computed by this point.
+        assert mot.value is not None
 
-        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
+        # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
         if format is None and tool_calls:
-            result.tool_calls = extract_model_tool_requests(tools, decoded_result)
+            mot.tool_calls = extract_model_tool_requests(tools, mot.value)
 
-        parsed_result = self.formatter.parse(action, result)
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            generate_log = GenerateLog()
-            generate_log.prompt = ctx_as_chat
-            generate_log.backend = f"vllm::{self.model_id!s}"
-            generate_log.model_options = model_options
-            generate_log.date = datetime.datetime.now()
-            generate_log.model_output = decoded_result
-            generate_log.extra = {
-                "format": format,
-                "tools_available": tools,
-                "tools_called": result.tool_calls,
-                "seed": model_options.get(ModelOption.SEED, None),
-            }
-            generate_log.action = action
-            generate_log.result = parsed_result
-            generate_logs.append(generate_log)
-        return parsed_result
+        assert mot._action is not None, (
+            "ModelOutputThunks should have their action assigned during generation"
+        )
+        assert mot._model_options is not None, (
+            "ModelOutputThunks should have their model_opts assigned during generation"
+        )
+
+        self.formatter.parse(mot._action, mot)
+
+        # Generate the log for this ModelOutputThunk.
+        generate_log = GenerateLog()
+        generate_log.prompt = conversation
+        generate_log.backend = f"vllm::{self.model_id!s}"
+        generate_log.model_options = mot._model_options
+        generate_log.date = datetime.datetime.now()
+        generate_log.model_output = mot.value
+        generate_log.extra = {
+            "format": format,
+            "tools_available": tools,
+            "tools_called": mot.tool_calls,
+            "seed": seed,
+        }
+        generate_log.action = mot._action
+        generate_log.result = mot
+
+        mot._generate_log = generate_log
 
     def _generate_from_raw(
         self,
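
For context on the new asynchronous path above: the engine's async generator is handed to send_to_queue inside an asyncio task, and the thunk's processing callback concatenates the DELTA chunks as they arrive. Below is a simplified, hypothetical stand-in for that pattern; mellea.helpers.async_helpers.send_to_queue and ModelOutputThunk's queue handling may differ in detail, so read it as a sketch of the producer/consumer shape rather than the library's implementation.

import asyncio
from collections.abc import AsyncIterator


async def pump(chunks: AsyncIterator, queue: asyncio.Queue) -> None:
    # Drain the engine's async generator into a queue, in the spirit of
    # send_to_queue, then signal completion with a sentinel.
    async for chunk in chunks:
        await queue.put(chunk)
    await queue.put(None)


async def consume(queue: asyncio.Queue) -> str:
    # Mirror processing(): accumulate the incremental text of each DELTA chunk.
    text = ""
    while (chunk := await queue.get()) is not None:
        text += chunk.outputs[0].text
    return text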
@@ -340,7 +401,10 @@ def _generate_from_raw(
         prompts = [self.formatter.print(action) for action in actions]
 
         sampling_params = vllm.SamplingParams(
-            **self._make_backend_specific_and_remove(model_options, vllm.SamplingParams)
+            **self._make_backend_specific_and_remove(
+                model_options, vllm.SamplingParams
+            ),
+            output_kind=vllm.sampling_params.RequestOutputKind.FINAL_ONLY,  # returns only the final results
         )
 
         if format is not None:
@@ -360,11 +424,18 @@
             [logits_processor] if logits_processor is not None else []
         )
 
-        ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
-            prompts, sampling_params=sampling_params
-        )  # type: ignore
+        async def generate(prompt, request_id):
+            async for result_output in self._model.generate(
+                request_id=request_id, prompt=prompt, sampling_params=sampling_params
+            ):
+                assert result_output.finished
+                return result_output.outputs[0].text
+
+        async def generate_all(prompts):
+            tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)]
+            return await asyncio.gather(*tasks)
 
-        decoded_results = [ro.outputs[0].text for ro in ros]
+        decoded_results = asyncio.run(generate_all(prompts))
 
         results = [ModelOutputThunk(value=text) for text in decoded_results]
 
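
One consequence of the two code paths above worth noting: asyncio.create_task (used in _generate_from_context_standard) only works when an event loop is already running, which is why a RuntimeError there points at a missing loop, while asyncio.run (used in _generate_from_raw) does the opposite and fails if a loop is already running in the same thread. A minimal illustration of that asymmetry, independent of vLLM and mellea:

import asyncio


async def inside_a_loop() -> None:
    # create_task needs the currently running loop; calling it outside one
    # raises "RuntimeError: no running event loop".
    task = asyncio.create_task(asyncio.sleep(0))
    await task


def outside_a_loop() -> None:
    # asyncio.run starts its own loop and raises RuntimeError if one is
    # already running in this thread.
    asyncio.run(asyncio.sleep(0))


outside_a_loop()
asyncio.run(inside_a_loop())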

0 commit comments
