@@ -6,8 +6,11 @@
 from __future__ import annotations
 
 import abc
+import asyncio
 import dataclasses
 import datetime
+import functools
+import importlib
 import inspect
 import json
 import os
@@ -20,7 +23,7 @@
 import outlines_core
 import torch
 import vllm  # type:ignore
-from transformers import PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from mellea.backends import BaseModelSubclass
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
@@ -32,12 +35,14 @@
 )
 from mellea.backends.types import ModelOption
 from mellea.backends.utils import extract_model_tool_requests, to_chat
+from mellea.helpers.async_helpers import send_to_queue
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
     Component,
     Context,
     GenerateLog,
+    GenerateType,
     ModelOutputThunk,
     TemplateRepresentation,
 )
@@ -114,7 +119,7 @@ def __init__(
         # vllm requires some model options during instantiation.
         engine_args = self._simplify_and_merge(model_options)
         engine_args = self._make_backend_specific_and_remove(
-            engine_args, vllm.EngineArgs
+            engine_args, vllm.AsyncEngineArgs
         )
 
         logger = FancyLogger.get_logger()
@@ -135,8 +140,8 @@ def __init__(
         while True:
             retry += 1
             try:
-                self._model = vllm.LLM(
-                    model=self._hf_model_id, task="generate", **engine_args
+                self._model = vllm.AsyncLLMEngine.from_engine_args(
+                    vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args)
                 )
                 break
             except torch._dynamo.exc.BackendCompilerFailed as e:
@@ -187,12 +192,18 @@ def __init__(
187192 f"max_num_seqs: { engine_args ['max_num_seqs' ]} \n "
188193 )
189194
190- self ._tokenizer : PreTrainedTokenizerBase = self ._model .get_tokenizer () # type:ignore
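+        # The tokenizer is loaded directly with transformers here since AsyncLLMEngine,
+        # unlike vllm.LLM, does not hand it back via a synchronous get_tokenizer() call.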
+        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+            self._hf_model_id
+        )  # type:ignore
 
-        # see notes in outlines.models.vllm.adapt_tokenizer
-        self._tokenizer_for_outlines: PreTrainedTokenizerBase = outlines.models.VLLM(
-            self._model
-        ).tokenizer  # type:ignore
+        # See the notes in outlines.models.vllm.adapt_tokenizer for why this is needed.
+        # Note: there is a module named outlines.models.vllm and a function named outlines.models.vllm.vllm.
+        # However, outlines.models imports outlines.models.vllm.vllm as vllm,
+        # so the module outlines.models.vllm becomes inaccessible,
+        # hence the use of importlib to get the module.
+        self._tokenizer_for_outlines: PreTrainedTokenizerBase = importlib.import_module(
+            "outlines.models.vllm"
+        ).adapt_tokenizer(self._tokenizer)
 
     def generate_from_context(
         self,
@@ -232,8 +243,6 @@ def _generate_from_context_standard(
         # Construct input.
         # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
         # Otherwise, we will linearize the context and treat it as a raw input.
-        decoded_result: str | None = None
-
         if ctx.is_chat_context:
             system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
             ctx_as_chat = to_chat(action, ctx, self.formatter, system_prompt)
@@ -265,7 +274,8 @@ def _generate_from_context_standard(
             sampling_params = vllm.SamplingParams(
                 **self._make_backend_specific_and_remove(
                     model_options, vllm.SamplingParams
-                )
+                ),
+                output_kind=vllm.sampling_params.RequestOutputKind.DELTA,  # returns results incrementally
             )
 
             if format is not None:
@@ -287,44 +297,95 @@ def _generate_from_context_standard(
                 [logits_processor] if logits_processor is not None else []
             )
 
-            ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
-                [input_str], sampling_params=sampling_params
+            # stream = model_options.get(ModelOption.STREAM, False)
+            # if stream:
+
+            output = ModelOutputThunk(None)
+
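+            # AsyncLLMEngine.generate is an async generator that yields RequestOutput
+            # chunks; each request needs a unique request_id, so the thunk's id() is used.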
+            generator = self._model.generate(  # type: ignore
+                request_id=str(id(output)),
+                prompt=input_str,
+                sampling_params=sampling_params,
             )  # type: ignore
 
-            decoded_results = [ro.outputs[0].text for ro in ros]
+            output._context = ctx.render_for_generation()
+            output._action = action
+            output._model_options = model_options
+
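+            # processing() consumes each streamed chunk as it arrives; post_processing()
+            # runs once the request finishes, with the extra arguments bound here.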
+            output._process = self.processing
+            output._post_process = functools.partial(
+                self.post_processing,
+                conversation=ctx_as_chat,
+                format=format,
+                tool_calls=tool_calls,
+                tools=tools,
+                seed=model_options.get(ModelOption.SEED, None),
+            )
+
+            try:
+                # This function should always be called from a running event loop so we don't have to worry about
+                # scheduling the task to a specific event loop here.
+                output._generate = asyncio.create_task(
+                    send_to_queue(generator, output._async_queue)  # type: ignore
+                )
+                output._generate_type = GenerateType.ASYNC
+            except RuntimeError as e:
+                # Most likely cause is running this function without an event loop present.
+                raise e
 
-            decoded_result = decoded_results[0]
+            return output
 
         else:
             raise Exception("Does not yet support non-chat contexts.")
 
-        assert decoded_result is not None
+    async def processing(self, mot: ModelOutputThunk, chunk: vllm.RequestOutput):
+        """Process the returned chunks or the complete response."""
+        if mot._underlying_value is None:
+            mot._underlying_value = ""
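+        # With RequestOutputKind.DELTA, each chunk carries only the newly generated text,
+        # so append it to whatever has accumulated so far.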
+        mot._underlying_value += chunk.outputs[0].text
+
+    async def post_processing(
+        self,
+        mot: ModelOutputThunk,
+        conversation: list[dict],
+        format: type[BaseModelSubclass] | None,
+        tool_calls: bool,
+        tools: dict[str, Callable],
+        seed,
+    ):
+        """Called when generation is done."""
 
-        result = ModelOutputThunk(value=decoded_result)
+        # The ModelOutputThunk must be computed by this point.
+        assert mot.value is not None
 
-        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
+        # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
         if format is None and tool_calls:
-            result.tool_calls = extract_model_tool_requests(tools, decoded_result)
+            mot.tool_calls = extract_model_tool_requests(tools, mot.value)
 
-        parsed_result = self.formatter.parse(action, result)
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            generate_log = GenerateLog()
-            generate_log.prompt = ctx_as_chat
-            generate_log.backend = f"vllm::{self.model_id!s}"
-            generate_log.model_options = model_options
-            generate_log.date = datetime.datetime.now()
-            generate_log.model_output = decoded_result
-            generate_log.extra = {
-                "format": format,
-                "tools_available": tools,
-                "tools_called": result.tool_calls,
-                "seed": model_options.get(ModelOption.SEED, None),
-            }
-            generate_log.action = action
-            generate_log.result = parsed_result
-            generate_logs.append(generate_log)
-        return parsed_result
+        assert mot._action is not None, (
+            "ModelOutputThunks should have their action assigned during generation"
+        )
+        assert mot._model_options is not None, (
+            "ModelOutputThunks should have their model_opts assigned during generation"
+        )
+
+        self.formatter.parse(mot._action, mot)
+
+        # Generate the log for this ModelOutputThunk.
+        generate_log = GenerateLog()
+        generate_log.prompt = conversation
+        generate_log.backend = f"vllm::{self.model_id!s}"
+        generate_log.model_options = mot._model_options
+        generate_log.date = datetime.datetime.now()
+        generate_log.model_output = mot.value
+        generate_log.extra = {
+            "format": format,
+            "tools_available": tools,
+            "tools_called": mot.tool_calls,
+            "seed": seed,
+        }
+        generate_log.action = mot._action
+        generate_log.result = mot
+
+        mot._generate_log = generate_log
 
     def _generate_from_raw(
         self,
@@ -340,7 +401,10 @@ def _generate_from_raw(
         prompts = [self.formatter.print(action) for action in actions]
 
         sampling_params = vllm.SamplingParams(
-            **self._make_backend_specific_and_remove(model_options, vllm.SamplingParams)
+            **self._make_backend_specific_and_remove(
+                model_options, vllm.SamplingParams
+            ),
+            output_kind=vllm.sampling_params.RequestOutputKind.FINAL_ONLY,  # returns only the final results
         )
 
         if format is not None:
@@ -360,11 +424,18 @@ def _generate_from_raw(
             [logits_processor] if logits_processor is not None else []
         )
 
-        ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
-            prompts, sampling_params=sampling_params
-        )  # type: ignore
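+        # Each prompt becomes its own request to the async engine; the requests are
+        # awaited concurrently with asyncio.gather.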
+        async def generate(prompt, request_id):
+            async for result_output in self._model.generate(
+                request_id=request_id, prompt=prompt, sampling_params=sampling_params
+            ):
+                assert result_output.finished
+                return result_output.outputs[0].text
+
+        async def generate_all(prompts):
+            tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)]
+            return await asyncio.gather(*tasks)
 
-        decoded_results = [ro.outputs[0].text for ro in ros]
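+        # asyncio.run starts a fresh event loop; this assumes _generate_from_raw is not
+        # being called from within an already-running loop.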
+        decoded_results = asyncio.run(generate_all(prompts))
 
         results = [ModelOutputThunk(value=text) for text in decoded_results]
 