mellea/backends/vllm.py: 26 additions & 3 deletions
@@ -35,7 +35,8 @@
convert_tools_to_json,
)
from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue
+from mellea.helpers.event_loop_helper import _run_async_in_thread
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
CBlock,
@@ -140,7 +141,7 @@ def __init__(
while True:
retry += 1
try:
-self._model = vllm.AsyncLLMEngine.from_engine_args(
+self._underlying_model = vllm.AsyncLLMEngine.from_engine_args(
vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args)
)
break
@@ -192,6 +193,9 @@ def __init__(
f"max_num_seqs: {engine_args['max_num_seqs']}\n"
)

# Keep track of the event loop the engine was instantiated in.
self._event_loop = get_current_event_loop()

self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
self._hf_model_id
) # type:ignore
@@ -205,6 +209,24 @@ def __init__(
"outlines.models.vllm"
).adapt_tokenizer(self._tokenizer)

@property
def _model(self) -> vllm.AsyncLLMEngine:
"""Use model when making generation requests."""
el = get_current_event_loop()

# vLLM attaches itself to the event loop that is running when the engine is
# instantiated or when the first generate request is made. Thankfully, vLLM
# provides helpers to reset that attachment; we do so here if the event loop changes.

# Most of the time, this should be a no-op. The event loop will only change
# if switching between async and sync calls.
if el != self._event_loop:
self._underlying_model.shutdown_background_loop()
self._underlying_model.start_background_loop()

Contributor Author:
They call that a background_loop but it's not an event loop, it's actually a Future. Even the _background_loop_unshielded is a Task object.
I think it's fine to manage the reference to the event loop on our side. We only ever have the one AsyncLLMEngine per LocalVLLMBackend so there shouldn't be issues with us tracking it this way. Happy to change it if it causes issues later on.

self._event_loop = el

return self._underlying_model
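
For context, a minimal illustrative sketch of the event-loop check above, assuming get_current_event_loop behaves like asyncio.get_running_loop() with a None fallback (the real helper in mellea.helpers.async_helpers may differ); this sketch is not part of the diff:

import asyncio
from typing import Optional

def get_current_event_loop() -> Optional[asyncio.AbstractEventLoop]:
    # Assumed behavior, for illustration only: return the currently running
    # loop if there is one, otherwise None. Sync callers (no running loop)
    # therefore compare as a different value than async callers, which is
    # what triggers the background-loop restart in the _model property.
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        return None

Under that assumption, the restart only happens when a caller switches between synchronous and asynchronous entry points, so repeated calls from the same context remain a no-op.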

def generate_from_context(
self,
action: Component | CBlock,
@@ -447,7 +469,8 @@ async def generate_all(prompts):
tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)]
return await asyncio.gather(*tasks)

-decoded_results = asyncio.run(generate_all(prompts))
+# Allow calling this from async functions.
+decoded_results = _run_async_in_thread(generate_all(prompts))

results = [ModelOutputThunk(value=text) for text in decoded_results]
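
As an aside, _run_async_in_thread is assumed to execute the coroutine on a separate worker thread that owns its own event loop, which is why the call above is safe even when the caller is already inside a running loop (where asyncio.run() would raise RuntimeError). A rough sketch under that assumption, not the actual helper:

import asyncio
import concurrent.futures
from typing import Any, Coroutine

def _run_async_in_thread(coro: Coroutine[Any, Any, Any]) -> Any:
    # Illustrative only: drive the coroutine to completion on a dedicated
    # thread with a fresh event loop, then hand the result back to the caller.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()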
