Commit 69b92f0

feat(sglang): asynchronous support
1 parent 443fbc7 commit 69b92f0

File tree

1 file changed: +90 -29 lines changed

mellea/backends/sglang.py

Lines changed: 90 additions & 29 deletions
@@ -6,8 +6,10 @@
 from __future__ import annotations
 
 import abc
+import asyncio
 import dataclasses
 import datetime
+import functools
 import inspect
 import json
 import os
@@ -16,7 +18,7 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import sglang as sgl  # type:ignore
-import torch
+from sglang.utils import async_stream_and_merge  # type:ignore
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from mellea.backends import BaseModelSubclass
@@ -29,12 +31,14 @@
 )
 from mellea.backends.types import ModelOption
 from mellea.backends.utils import extract_model_tool_requests, to_chat
+from mellea.helpers.async_helpers import send_to_queue
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
     Component,
     Context,
     GenerateLog,
+    GenerateType,
     ModelOutputThunk,
     TemplateRepresentation,
 )
@@ -145,7 +149,6 @@ def _generate_from_context_standard(
         # Construct input.
         # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
         # Otherwise, we will linearize the context and treat it as a raw input.
-        decoded_result: str | None = None
 
         if ctx.is_chat_context:
             system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
@@ -182,42 +185,100 @@ def _generate_from_context_standard(
             if format is not None:
                 sampling_params["json_schema"] = json.dumps(format.model_json_schema())
 
-            output: dict[str, Any] = self._model.generate(  # type: ignore
-                input_str, sampling_params=sampling_params
-            )  # type: ignore
+            if model_options.get(ModelOption.STREAM, False):
+                generator = async_stream_and_merge(
+                    self._model,  # type: ignore
+                    input_str,
+                    sampling_params=sampling_params,
+                )  # type: ignore
+            else:
+                generator = self._model.async_generate(  # type: ignore
+                    input_str, sampling_params=sampling_params
+                )  # type: ignore
+
+            output = ModelOutputThunk(None)
+            output._context = ctx.render_for_generation()
+            output._action = action
+            output._model_options = model_options
+
+            output._process = self.processing
+            output._post_process = functools.partial(
+                self.post_processing,
+                conversation=ctx_as_chat,
+                tool_calls=tool_calls,
+                tools=tools,
+                seed=model_options.get(ModelOption.SEED, None),
+            )
+
+            try:
+                # This function should always be called from a running event loop so we don't have to worry about
+                # scheduling the task to a specific event loop here.
+                output._generate = asyncio.create_task(
+                    send_to_queue(generator, output._async_queue)  # type: ignore
+                )
+                output._generate_type = GenerateType.ASYNC
+            except RuntimeError as e:
+                # Most likely cause is running this function without an event loop present.
+                raise e
 
-            decoded_result = output["text"]
+            return output
 
         else:
             raise Exception("Does not yet support non-chat contexts.")
 
-        assert decoded_result is not None
+    async def processing(self, mot: ModelOutputThunk, chunk: str | dict[str, Any]):
+        """Process the returned chunks or the complete response."""
+
+        if isinstance(chunk, str):  # via async_stream_and_merge
+            if mot._underlying_value is None:
+                mot._underlying_value = ""
+            mot._underlying_value += chunk
+        else:
+            mot._underlying_value = chunk["text"]
+
+    async def post_processing(
+        self,
+        mot: ModelOutputThunk,
+        conversation: list[dict],
+        tool_calls: bool,
+        tools: dict[str, Callable],
+        seed,
+    ):
+        """Called when generation is done."""
 
-        result = ModelOutputThunk(value=decoded_result)
+        # The ModelOutputThunk must be computed by this point.
+        assert mot.value is not None
 
-        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
+        # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
         if format is None and tool_calls:
-            result.tool_calls = extract_model_tool_requests(tools, decoded_result)
+            mot.tool_calls = extract_model_tool_requests(tools, mot.value)
 
-        parsed_result = self.formatter.parse(action, result)
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            generate_log = GenerateLog()
-            generate_log.prompt = ctx_as_chat
-            generate_log.backend = f"sglang::{self.model_id!s}"
-            generate_log.model_options = model_options
-            generate_log.date = datetime.datetime.now()
-            generate_log.model_output = decoded_result
-            generate_log.extra = {
-                "format": format,
-                "tools_available": tools,
-                "tools_called": result.tool_calls,
-                "seed": model_options.get(ModelOption.SEED, None),
-            }
-            generate_log.action = action
-            generate_log.result = parsed_result
-            generate_logs.append(generate_log)
-        return parsed_result
+        assert mot._action is not None, (
+            "ModelOutputThunks should have their action assigned during generation"
+        )
+        assert mot._model_options is not None, (
+            "ModelOutputThunks should have their model_opts assigned during generation"
+        )
+
+        self.formatter.parse(mot._action, mot)
+
+        # Generate the log for this ModelOutputThunk.
+        generate_log = GenerateLog()
+        generate_log.prompt = conversation
+        generate_log.backend = f"sglang::{self.model_id!s}"
+        generate_log.model_options = mot._model_options
+        generate_log.date = datetime.datetime.now()
+        generate_log.model_output = mot.value
+        generate_log.extra = {
+            "format": format,
+            "tools_available": tools,
+            "tools_called": mot.tool_calls,
+            "seed": seed,
+        }
+        generate_log.action = mot._action
+        generate_log.result = mot
+
+        mot._generate_log = generate_log
 
     def _generate_from_raw(
         self,
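
For context on the pattern this commit adopts: generation no longer blocks on self._model.generate. Instead, a streaming async generator (via async_stream_and_merge) or a single awaitable (via async_generate) is handed to a background task that forwards results into the thunk's async queue, processing folds each chunk into the value, and post_processing finalizes parsing and logging once generation completes. The sketch below illustrates that queue handoff in isolation; the send_to_queue stand-in, its None sentinel, and the str/dict chunk shapes are assumptions for demonstration, not mellea's actual implementation.

import asyncio
from typing import Any, AsyncIterator


async def send_to_queue(source: Any, queue: asyncio.Queue) -> None:
    # Illustrative stand-in for mellea.helpers.async_helpers.send_to_queue;
    # the real helper's behavior may differ.
    if hasattr(source, "__aiter__"):
        # Streaming case: forward each chunk as it arrives.
        async for chunk in source:
            await queue.put(chunk)
    else:
        # Non-streaming case: await the single result.
        await queue.put(await source)
    await queue.put(None)  # assumed sentinel marking end of generation


async def consume(queue: asyncio.Queue) -> str:
    # Mirrors `processing` above: str chunks are concatenated (streaming),
    # a dict carries the full text under "text" (non-streaming).
    text = ""
    while (chunk := await queue.get()) is not None:
        text += chunk if isinstance(chunk, str) else chunk["text"]
    return text


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()

    async def fake_stream() -> AsyncIterator[str]:
        for piece in ("Hello", ", ", "world"):
            yield piece

    # As in the commit, the task must be created from a running event loop.
    task = asyncio.create_task(send_to_queue(fake_stream(), queue))
    print(await consume(queue))  # -> Hello, world
    await task


asyncio.run(main())

Decoupling scheduling from consumption this way is what lets _generate_from_context_standard return the ModelOutputThunk immediately while the event loop fills it in.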
