Commit cc9a691

feat(sglang): asynchronous support
1 parent 9b1cae6 commit cc9a691


mellea/backends/sglang.py

Lines changed: 90 additions & 29 deletions
@@ -6,8 +6,10 @@
 from __future__ import annotations

 import abc
+import asyncio
 import dataclasses
 import datetime
+import functools
 import inspect
 import json
 import os
@@ -17,7 +19,7 @@

 import nest_asyncio
 import sglang as sgl  # type:ignore
-import torch
+from sglang.utils import async_stream_and_merge  # type:ignore
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

 from mellea.backends import BaseModelSubclass
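
The swapped-in import is sglang's streaming helper. Per sglang's offline engine documentation, async_stream_and_merge(llm, prompt, sampling_params) iterates the engine's streaming output and yields merged text chunks. A minimal usage sketch, assuming a local sgl.Engine and a placeholder model path:

import asyncio

import sglang as sgl
from sglang.utils import async_stream_and_merge

async def main() -> None:
    # Placeholder model path; any local sglang-supported model works.
    llm = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")
    sampling_params = {"temperature": 0.8, "max_new_tokens": 32}

    # Yields cleaned text deltas as they stream from the engine.
    async for chunk in async_stream_and_merge(
        llm, "The capital of France is", sampling_params
    ):
        print(chunk, end="", flush=True)

    llm.shutdown()

asyncio.run(main())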
@@ -30,12 +32,14 @@
 )
 from mellea.backends.types import ModelOption
 from mellea.backends.utils import extract_model_tool_requests, to_chat
+from mellea.helpers.async_helpers import send_to_queue
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
     Component,
     Context,
     GenerateLog,
+    GenerateType,
     ModelOutputThunk,
     TemplateRepresentation,
 )
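
send_to_queue comes from mellea's async helpers and is not shown in this diff. Judging from its call site below, where it receives either an async generator (streaming) or a plain coroutine (non-streaming) plus the thunk's queue, a helper of this shape would drain the producer into the queue and mark completion. This is a hypothetical sketch, not mellea's actual implementation:

import asyncio
import inspect
from typing import Any

async def send_to_queue(gen: Any, queue: asyncio.Queue) -> None:
    # Hypothetical sketch: normalize both producer shapes into queue items.
    if inspect.isasyncgen(gen):
        # Streaming case: forward each chunk as it arrives.
        async for chunk in gen:
            await queue.put(chunk)
    else:
        # Non-streaming case: await the single result and forward it.
        await queue.put(await gen)
    # Assumed end-of-stream sentinel so the consumer knows generation finished.
    await queue.put(None)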
@@ -148,7 +152,6 @@ def _generate_from_context_standard(
         # Construct input.
         # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
         # Otherwise, we will linearize the context and treat it as a raw input.
-        decoded_result: str | None = None

         if ctx.is_chat_context:
             system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
@@ -185,42 +188,100 @@ def _generate_from_context_standard(
             if format is not None:
                 sampling_params["json_schema"] = json.dumps(format.model_json_schema())

-            output: dict[str, Any] = self._model.generate(  # type: ignore
-                input_str, sampling_params=sampling_params
-            )  # type: ignore
+            if model_options.get(ModelOption.STREAM, False):
+                generator = async_stream_and_merge(
+                    self._model,  # type: ignore
+                    input_str,
+                    sampling_params=sampling_params,
+                )  # type: ignore
+            else:
+                generator = self._model.async_generate(  # type: ignore
+                    input_str, sampling_params=sampling_params
+                )  # type: ignore
+
+            output = ModelOutputThunk(None)
+            output._context = ctx.render_for_generation()
+            output._action = action
+            output._model_options = model_options
+
+            output._process = self.processing
+            output._post_process = functools.partial(
+                self.post_processing,
+                conversation=ctx_as_chat,
+                tool_calls=tool_calls,
+                tools=tools,
+                seed=model_options.get(ModelOption.SEED, None),
+            )
+
+            try:
+                # This function should always be called from a running event loop so we don't have to worry about
+                # scheduling the task to a specific event loop here.
+                output._generate = asyncio.create_task(
+                    send_to_queue(generator, output._async_queue)  # type: ignore
+                )
+                output._generate_type = GenerateType.ASYNC
+            except RuntimeError as e:
+                # Most likely cause is running this function without an event loop present.
+                raise e

-            decoded_result = output["text"]
+            return output

         else:
             raise Exception("Does not yet support non-chat contexts.")

-        assert decoded_result is not None
+    async def processing(self, mot: ModelOutputThunk, chunk: str | dict[str, Any]):
+        """Process the returned chunks or the complete response."""
+
+        if isinstance(chunk, str):  # via async_stream_and_merge
+            if mot._underlying_value is None:
+                mot._underlying_value = ""
+            mot._underlying_value += chunk
+        else:
+            mot._underlying_value = chunk["text"]
+
+    async def post_processing(
+        self,
+        mot: ModelOutputThunk,
+        conversation: list[dict],
+        tool_calls: bool,
+        tools: dict[str, Callable],
+        seed,
+    ):
+        """Called when generation is done."""

-        result = ModelOutputThunk(value=decoded_result)
+        # The ModelOutputThunk must be computed by this point.
+        assert mot.value is not None

-        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
+        # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
         if format is None and tool_calls:
-            result.tool_calls = extract_model_tool_requests(tools, decoded_result)
+            mot.tool_calls = extract_model_tool_requests(tools, mot.value)

-        parsed_result = self.formatter.parse(action, result)
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            generate_log = GenerateLog()
-            generate_log.prompt = ctx_as_chat
-            generate_log.backend = f"sglang::{self.model_id!s}"
-            generate_log.model_options = model_options
-            generate_log.date = datetime.datetime.now()
-            generate_log.model_output = decoded_result
-            generate_log.extra = {
-                "format": format,
-                "tools_available": tools,
-                "tools_called": result.tool_calls,
-                "seed": model_options.get(ModelOption.SEED, None),
-            }
-            generate_log.action = action
-            generate_log.result = parsed_result
-            generate_logs.append(generate_log)
-        return parsed_result
+        assert mot._action is not None, (
+            "ModelOutputThunks should have their action assigned during generation"
+        )
+        assert mot._model_options is not None, (
+            "ModelOutputThunks should have their model_opts assigned during generation"
+        )
+
+        self.formatter.parse(mot._action, mot)
+
+        # Generate the log for this ModelOutputThunk.
+        generate_log = GenerateLog()
+        generate_log.prompt = conversation
+        generate_log.backend = f"sglang::{self.model_id!s}"
+        generate_log.model_options = mot._model_options
+        generate_log.date = datetime.datetime.now()
+        generate_log.model_output = mot.value
+        generate_log.extra = {
+            "format": format,
+            "tools_available": tools,
+            "tools_called": mot.tool_calls,
+            "seed": seed,
+        }
+        generate_log.action = mot._action
+        generate_log.result = mot
+
+        mot._generate_log = generate_log

     def _generate_from_raw(
         self,
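
The try/except around asyncio.create_task reflects a real asyncio constraint: create_task schedules work on the currently running event loop and raises RuntimeError when no loop is running, which is the failure mode the in-diff comments call out. A small standalone demonstration:

import asyncio

async def coro() -> str:
    return "ok"

async def main() -> None:
    # Inside a running loop, create_task schedules the coroutine immediately.
    task = asyncio.create_task(coro())
    print(await task)  # -> ok

asyncio.run(main())

# Outside any loop, create_task raises RuntimeError ("no running event loop"),
# the case the diff's except branch anticipates.
c = coro()
try:
    asyncio.create_task(c)
except RuntimeError as err:
    print(err)
finally:
    c.close()  # avoid a "coroutine was never awaited" warning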
