6 | 6 | from __future__ import annotations |
7 | 7 |
8 | 8 | import abc |
| 9 | +import asyncio |
9 | 10 | import dataclasses |
10 | 11 | import datetime |
| 12 | +import functools |
11 | 13 | import inspect |
12 | 14 | import json |
13 | 15 | import os |
17 | 19 |
18 | 20 | import nest_asyncio |
19 | 21 | import sglang as sgl # type:ignore |
20 | | -import torch |
| 22 | +from sglang.utils import async_stream_and_merge # type:ignore |
21 | 23 | from transformers import AutoTokenizer, PreTrainedTokenizerBase |
22 | 24 |
23 | 25 | from mellea.backends import BaseModelSubclass |
30 | 32 | ) |
31 | 33 | from mellea.backends.types import ModelOption |
32 | 34 | from mellea.backends.utils import extract_model_tool_requests, to_chat |
| 35 | +from mellea.helpers.async_helpers import send_to_queue |
33 | 36 | from mellea.helpers.fancy_logger import FancyLogger |
34 | 37 | from mellea.stdlib.base import ( |
35 | 38 | CBlock, |
36 | 39 | Component, |
37 | 40 | Context, |
38 | 41 | GenerateLog, |
| 42 | + GenerateType, |
39 | 43 | ModelOutputThunk, |
40 | 44 | TemplateRepresentation, |
41 | 45 | ) |
@@ -148,7 +152,6 @@ def _generate_from_context_standard( |
148 | 152 | # Construct input. |
149 | 153 | # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. |
150 | 154 | # Otherwise, we will linearize the context and treat it as a raw input. |
151 | | - decoded_result: str | None = None |
152 | 155 |
153 | 156 | if ctx.is_chat_context: |
154 | 157 | system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None) |
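The comment above refers to the standard chat-template flow from `transformers`. As a rough sketch of that step (not the backend's exact call; the model name and arguments here are illustrative assumptions):

```python
# Sketch of the apply_chat_template step referred to above; the model name and
# exact arguments are assumptions, not taken from this diff.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat model
messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Say hello."},
]
# tokenize=False returns the fully formatted prompt string the engine consumes.
input_str = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
```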
@@ -185,42 +188,100 @@ def _generate_from_context_standard( |
185 | 188 | if format is not None: |
186 | 189 | sampling_params["json_schema"] = json.dumps(format.model_json_schema()) |
187 | 190 |
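For orientation, a minimal sketch of what typically feeds this path: the `ModelOption` keys used in this change (`SYSTEM_PROMPT`, `STREAM`, `SEED`) plus a `format=` pydantic model whose JSON schema ends up in `sampling_params["json_schema"]`. The `Answer` model and the option values are made up for illustration:

```python
# Illustrative inputs only; `Answer` is a hypothetical pydantic model.
from pydantic import BaseModel

from mellea.backends.types import ModelOption


class Answer(BaseModel):
    verdict: str
    confidence: float


model_options = {
    ModelOption.SYSTEM_PROMPT: "You are a terse assistant.",
    ModelOption.STREAM: True,  # selects the streaming branch below
    ModelOption.SEED: 1234,    # surfaced later in GenerateLog.extra["seed"]
}
# Passing format=Answer would constrain decoding to Answer's JSON schema.
```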
188 | | - output: dict[str, Any] = self._model.generate( # type: ignore |
189 | | - input_str, sampling_params=sampling_params |
190 | | - ) # type: ignore |
| 191 | + if model_options.get(ModelOption.STREAM, False): |
| 192 | + generator = async_stream_and_merge( |
| 193 | + self._model, # type: ignore |
| 194 | + input_str, |
| 195 | + sampling_params=sampling_params, |
| 196 | + ) # type: ignore |
| 197 | + else: |
| 198 | + generator = self._model.async_generate( # type: ignore |
| 199 | + input_str, sampling_params=sampling_params |
| 200 | + ) # type: ignore |
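The two branches produce different shapes: `async_stream_and_merge` is an async generator yielding incremental text chunks, while `Engine.async_generate` is a coroutine that resolves to a dict with the completion under `"text"`. A standalone sketch of consuming each, assuming an offline `sgl.Engine` named `llm` (other keys in the result dict may vary across sglang versions):

```python
# Standalone sketch, not part of the change: how the two generator shapes behave.
import sglang as sgl
from sglang.utils import async_stream_and_merge


async def complete(llm: sgl.Engine, prompt: str, sampling_params: dict, stream: bool) -> str:
    if stream:
        # Async generator: text arrives incrementally as merged chunks.
        text = ""
        async for chunk in async_stream_and_merge(llm, prompt, sampling_params=sampling_params):
            text += chunk
        return text
    # Coroutine: resolves once, with the full completion under "text".
    result = await llm.async_generate(prompt, sampling_params=sampling_params)
    return result["text"]
```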
| 201 | + |
| 202 | + output = ModelOutputThunk(None) |
| 203 | + output._context = ctx.render_for_generation() |
| 204 | + output._action = action |
| 205 | + output._model_options = model_options |
| 206 | + |
| 207 | + output._process = self.processing |
| 208 | + output._post_process = functools.partial( |
| 209 | + self.post_processing, |
| 210 | + conversation=ctx_as_chat, |
| 211 | + tool_calls=tool_calls, |
| 212 | + tools=tools, |
| 213 | + seed=model_options.get(ModelOption.SEED, None), |
| 214 | + ) |
| 215 | + |
| 216 | + try: |
| 217 | + # This function should always be called from a running event loop so we don't have to worry about |
| 218 | + # scheduling the task to a specific event loop here. |
| 219 | + output._generate = asyncio.create_task( |
| 220 | + send_to_queue(generator, output._async_queue) # type: ignore |
| 221 | + ) |
| 222 | + output._generate_type = GenerateType.ASYNC |
| 223 | + except RuntimeError as e: |
| 224 | + # Most likely cause is running this function without an event loop present. |
| 225 | + raise e |
191 | 226 |
192 | | - decoded_result = output["text"] |
| 227 | + return output |
193 | 228 |
194 | 229 | else: |
195 | 230 | raise Exception("Does not yet support non-chat contexts.") |
196 | 231 |
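`send_to_queue` (imported from `mellea.helpers.async_helpers`) is not shown in this diff. Purely as a sketch of the contract the code above assumes, it would drain either shape of `generator` into the thunk's `_async_queue` and signal completion; the real helper may differ in sentinel choice and error handling:

```python
# Hypothetical sketch; the actual mellea.helpers.async_helpers.send_to_queue
# is not part of this diff and may be implemented differently.
from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator, Coroutine
from typing import Any

_DONE = object()  # assumed end-of-stream sentinel


async def send_to_queue(
    source: AsyncGenerator[Any, None] | Coroutine[Any, Any, Any],
    queue: asyncio.Queue,
) -> None:
    if isinstance(source, AsyncGenerator):
        # Streaming: forward each merged text chunk as it arrives.
        async for item in source:
            await queue.put(item)
    else:
        # Non-streaming: a single result dict from async_generate.
        await queue.put(await source)
    await queue.put(_DONE)  # tell the consumer that generation is finished
```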
197 | | - assert decoded_result is not None |
| 232 | + async def processing(self, mot: ModelOutputThunk, chunk: str | dict[str, Any]): |
| 233 | + """Process the returned chunks or the complete response.""" |
| 234 | + |
| 235 | + if isinstance(chunk, str): # via async_stream_and_merge |
| 236 | + if mot._underlying_value is None: |
| 237 | + mot._underlying_value = "" |
| 238 | + mot._underlying_value += chunk |
| 239 | + else: |
| 240 | + mot._underlying_value = chunk["text"] |
| 241 | + |
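`processing` runs once per queue item, so its two branches mirror the two generator shapes: many `str` chunks in the streaming case, one result dict otherwise. A compact illustration with made-up values:

```python
# Illustration only: the two inputs `processing` distinguishes.
streamed = ["Hel", "lo", "!"]   # chunks yielded by async_stream_and_merge
single = {"text": "Hello!"}     # dict resolved by Engine.async_generate

# Streaming: chunks accumulate into the thunk's underlying value.
value = ""
for chunk in streamed:
    value += chunk
assert value == "Hello!"

# Non-streaming: the whole completion is read from the "text" field.
assert single["text"] == value
```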
| 242 | + async def post_processing( |
| 243 | + self, |
| 244 | + mot: ModelOutputThunk, |
| 245 | + conversation: list[dict], |
| 246 | + tool_calls: bool, |
| 247 | + tools: dict[str, Callable], |
| 248 | + seed, |
| 249 | + ): |
| 250 | + """Called when generation is done.""" |
198 | 251 |
199 | | - result = ModelOutputThunk(value=decoded_result) |
| 252 | + # The ModelOutputThunk must be computed by this point. |
| 253 | + assert mot.value is not None |
200 | 254 |
201 | | - # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model. |
| 255 | + # Only scan for tools if we are not doing structured output and tool calls were provided to the model. |
202 | 256 | if format is None and tool_calls: |
203 | | - result.tool_calls = extract_model_tool_requests(tools, decoded_result) |
| 257 | + mot.tool_calls = extract_model_tool_requests(tools, mot.value) |
204 | 258 |
205 | | - parsed_result = self.formatter.parse(action, result) |
206 | | - if generate_logs is not None: |
207 | | - assert isinstance(generate_logs, list) |
208 | | - generate_log = GenerateLog() |
209 | | - generate_log.prompt = ctx_as_chat |
210 | | - generate_log.backend = f"sglang::{self.model_id!s}" |
211 | | - generate_log.model_options = model_options |
212 | | - generate_log.date = datetime.datetime.now() |
213 | | - generate_log.model_output = decoded_result |
214 | | - generate_log.extra = { |
215 | | - "format": format, |
216 | | - "tools_available": tools, |
217 | | - "tools_called": result.tool_calls, |
218 | | - "seed": model_options.get(ModelOption.SEED, None), |
219 | | - } |
220 | | - generate_log.action = action |
221 | | - generate_log.result = parsed_result |
222 | | - generate_logs.append(generate_log) |
223 | | - return parsed_result |
| 259 | + assert mot._action is not None, ( |
| 260 | + "ModelOutputThunks should have their action assigned during generation" |
| 261 | + ) |
| 262 | + assert mot._model_options is not None, ( |
| 263 | + "ModelOutputThunks should have their model_opts assigned during generation" |
| 264 | + ) |
| 265 | + |
| 266 | + self.formatter.parse(mot._action, mot) |
| 267 | + |
| 268 | + # Generate the log for this ModelOutputThunk. |
| 269 | + generate_log = GenerateLog() |
| 270 | + generate_log.prompt = conversation |
| 271 | + generate_log.backend = f"sglang::{self.model_id!s}" |
| 272 | + generate_log.model_options = mot._model_options |
| 273 | + generate_log.date = datetime.datetime.now() |
| 274 | + generate_log.model_output = mot.value |
| 275 | + generate_log.extra = { |
| 276 | + "format": format, |
| 277 | + "tools_available": tools, |
| 278 | + "tools_called": mot.tool_calls, |
| 279 | + "seed": seed, |
| 280 | + } |
| 281 | + generate_log.action = mot._action |
| 282 | + generate_log.result = mot |
| 283 | + |
| 284 | + mot._generate_log = generate_log |
224 | 285 |
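Driving these hooks is outside this diff: some consumer has to drain `_async_queue`, call `_process` per item, and call `_post_process` once at the end. A hypothetical consumer, reusing the `_DONE` sentinel assumed in the `send_to_queue` sketch above (error handling elided; mellea's real thunk-resolution code may look quite different):

```python
# Hypothetical consumer of the hooks wired up in _generate_from_context_standard;
# not mellea's actual resolution code.
async def resolve(mot) -> str:
    while True:
        item = await mot._async_queue.get()
        if item is _DONE:               # sentinel assumed in the earlier sketch
            break
        await mot._process(mot, item)   # self.processing, bound by the backend
    await mot._generate                 # surface any exception from the producer task
    await mot._post_process(mot)        # self.post_processing via functools.partial
    return mot.value
```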
225 | 286 | def _generate_from_raw( |
226 | 287 | self, |