6 | 6 | from __future__ import annotations
7 | 7 |
8 | 8 | import abc |
| 9 | +import asyncio |
9 | 10 | import dataclasses |
10 | 11 | import datetime |
| 12 | +import functools |
11 | 13 | import inspect |
12 | 14 | import json |
13 | 15 | import os |
16 | 18 | from typing import TYPE_CHECKING, Any, Optional |
17 | 19 |
18 | 20 | import sglang as sgl # type:ignore |
19 | | -import torch |
| 21 | +from sglang.utils import async_stream_and_merge # type:ignore |
20 | 22 | from transformers import AutoTokenizer, PreTrainedTokenizerBase |
21 | 23 |
22 | 24 | from mellea.backends import BaseModelSubclass |
29 | 31 | ) |
30 | 32 | from mellea.backends.types import ModelOption |
31 | 33 | from mellea.backends.utils import extract_model_tool_requests, to_chat |
| 34 | +from mellea.helpers.async_helpers import send_to_queue |
32 | 35 | from mellea.helpers.fancy_logger import FancyLogger |
33 | 36 | from mellea.stdlib.base import ( |
34 | 37 |     CBlock,
35 | 38 |     Component,
36 | 39 |     Context,
37 | 40 |     GenerateLog,
| 41 | +    GenerateType,
38 | 42 |     ModelOutputThunk,
39 | 43 |     TemplateRepresentation,
40 | 44 | ) |
@@ -145,7 +149,6 @@ def _generate_from_context_standard( |
145 | 149 |         # Construct input.
146 | 150 |         # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
147 | 151 |         # Otherwise, we will linearize the context and treat it as a raw input.
148 | | -        decoded_result: str | None = None
149 | 152 |
150 | 153 |         if ctx.is_chat_context:
151 | 154 |             system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
@@ -182,42 +185,100 @@ def _generate_from_context_standard( |
182 | 185 |             if format is not None:
183 | 186 |                 sampling_params["json_schema"] = json.dumps(format.model_json_schema())
184 | 187 |
185 | | -            output: dict[str, Any] = self._model.generate(  # type: ignore
186 | | -                input_str, sampling_params=sampling_params
187 | | -            )  # type: ignore
| 188 | +            if model_options.get(ModelOption.STREAM, False):
| 189 | +                generator = async_stream_and_merge(
| 190 | +                    self._model,  # type: ignore
| 191 | +                    input_str,
| 192 | +                    sampling_params=sampling_params,
| 193 | +                )  # type: ignore
| 194 | +            else:
| 195 | +                generator = self._model.async_generate(  # type: ignore
| 196 | +                    input_str, sampling_params=sampling_params
| 197 | +                )  # type: ignore
| 198 | +
| 199 | +            output = ModelOutputThunk(None)
| 200 | +            output._context = ctx.render_for_generation()
| 201 | +            output._action = action
| 202 | +            output._model_options = model_options
| 203 | +
| 204 | +            output._process = self.processing
| 205 | +            output._post_process = functools.partial(
| 206 | +                self.post_processing,
| 207 | +                conversation=ctx_as_chat,
| 208 | +                tool_calls=tool_calls,
| 209 | +                tools=tools,
| 210 | +                seed=model_options.get(ModelOption.SEED, None),
| 211 | +            )
| 212 | +
| 213 | +            try:
| 214 | +                # This function should always be called from a running event loop so we don't have to worry about
| 215 | +                # scheduling the task to a specific event loop here.
| 216 | +                output._generate = asyncio.create_task(
| 217 | +                    send_to_queue(generator, output._async_queue)  # type: ignore
| 218 | +                )
| 219 | +                output._generate_type = GenerateType.ASYNC
| 220 | +            except RuntimeError as e:
| 221 | +                # Most likely cause is running this function without an event loop present.
| 222 | +                raise e
188 | 223 |
189 | | -            decoded_result = output["text"]
| 224 | +            return output
190 | 225 |
191 | 226 |         else:
192 | 227 |             raise Exception("Does not yet support non-chat contexts.")
193 | 228 |
194 | | -        assert decoded_result is not None
| 229 | +    async def processing(self, mot: ModelOutputThunk, chunk: str | dict[str, Any]):
| 230 | +        """Process the returned chunks or the complete response."""
| 231 | +
| 232 | +        if isinstance(chunk, str):  # via async_stream_and_merge
| 233 | +            if mot._underlying_value is None:
| 234 | +                mot._underlying_value = ""
| 235 | +            mot._underlying_value += chunk
| 236 | +        else:
| 237 | +            mot._underlying_value = chunk["text"]
| 238 | +
| 239 | +    async def post_processing(
| 240 | +        self,
| 241 | +        mot: ModelOutputThunk,
| 242 | +        conversation: list[dict],
| 243 | +        tool_calls: bool,
| 244 | +        tools: dict[str, Callable],
| 245 | +        seed,
| 246 | +    ):
| 247 | +        """Called when generation is done."""
195 | 248 |
196 | | -        result = ModelOutputThunk(value=decoded_result)
| 249 | +        # The ModelOutputThunk must be computed by this point.
| 250 | +        assert mot.value is not None
197 | 251 |
198 | | -        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
| 252 | +        # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
199 | 253 |         if format is None and tool_calls:
200 | | -            result.tool_calls = extract_model_tool_requests(tools, decoded_result)
| 254 | +            mot.tool_calls = extract_model_tool_requests(tools, mot.value)
201 | 255 |
202 | | -        parsed_result = self.formatter.parse(action, result)
203 | | -        if generate_logs is not None:
204 | | -            assert isinstance(generate_logs, list)
205 | | -            generate_log = GenerateLog()
206 | | -            generate_log.prompt = ctx_as_chat
207 | | -            generate_log.backend = f"sglang::{self.model_id!s}"
208 | | -            generate_log.model_options = model_options
209 | | -            generate_log.date = datetime.datetime.now()
210 | | -            generate_log.model_output = decoded_result
211 | | -            generate_log.extra = {
212 | | -                "format": format,
213 | | -                "tools_available": tools,
214 | | -                "tools_called": result.tool_calls,
215 | | -                "seed": model_options.get(ModelOption.SEED, None),
216 | | -            }
217 | | -            generate_log.action = action
218 | | -            generate_log.result = parsed_result
219 | | -            generate_logs.append(generate_log)
220 | | -        return parsed_result
| 256 | +        assert mot._action is not None, (
| 257 | +            "ModelOutputThunks should have their action assigned during generation"
| 258 | +        )
| 259 | +        assert mot._model_options is not None, (
| 260 | +            "ModelOutputThunks should have their model_opts assigned during generation"
| 261 | +        )
| 262 | +
| 263 | +        self.formatter.parse(mot._action, mot)
| 264 | +
| 265 | +        # Generate the log for this ModelOutputThunk.
| 266 | +        generate_log = GenerateLog()
| 267 | +        generate_log.prompt = conversation
| 268 | +        generate_log.backend = f"sglang::{self.model_id!s}"
| 269 | +        generate_log.model_options = mot._model_options
| 270 | +        generate_log.date = datetime.datetime.now()
| 271 | +        generate_log.model_output = mot.value
| 272 | +        generate_log.extra = {
| 273 | +            "format": format,
| 274 | +            "tools_available": tools,
| 275 | +            "tools_called": mot.tool_calls,
| 276 | +            "seed": seed,
| 277 | +        }
| 278 | +        generate_log.action = mot._action
| 279 | +        generate_log.result = mot
| 280 | +
| 281 | +        mot._generate_log = generate_log
221 | 282 |
222 | 283 |     def _generate_from_raw(
223 | 284 |         self,
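The change above replaces the blocking self._model.generate call with a producer/consumer pattern: the SGLang generator (streaming or not) is wrapped in an asyncio task that pushes chunks onto the ModelOutputThunk's queue via send_to_queue, processing() accumulates those chunks into the thunk's value, and post_processing() parses the result and builds the GenerateLog once generation finishes. The sketch below is a minimal, self-contained illustration of that queue pattern using only the standard library; the send_to_queue and fake_stream defined here are illustrative stand-ins, not mellea's actual helpers or API.

import asyncio
from collections.abc import AsyncIterator

async def send_to_queue(stream: AsyncIterator[str], queue: asyncio.Queue) -> None:
    # Illustrative stand-in: drain the async generator into the queue,
    # then signal end-of-stream with a None sentinel.
    async for chunk in stream:
        await queue.put(chunk)
    await queue.put(None)

async def fake_stream() -> AsyncIterator[str]:
    # Hypothetical generator standing in for async_stream_and_merge / async_generate.
    for piece in ["Hel", "lo, ", "world!"]:
        await asyncio.sleep(0)
        yield piece

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    value = None

    # Like output._generate = asyncio.create_task(send_to_queue(...)):
    # generation keeps running in the background after this call returns.
    producer = asyncio.create_task(send_to_queue(fake_stream(), queue))

    while (chunk := await queue.get()) is not None:
        # Like processing(): accumulate streamed chunks into the value.
        value = (value or "") + chunk

    await producer
    # Like post_processing(): runs once the value is fully computed.
    assert value is not None
    print(value)

asyncio.run(main())

Because the producer runs as a background task, _generate_from_context_standard can return the ModelOutputThunk immediately and resolve its value later, which is why the new code expects to be called from a running event loop.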