| 1 | +"""A generic LiteLLM compatible backend that wraps around the openai python sdk.""" |

import datetime
import json
from collections.abc import Callable
from typing import Any

import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.get_supported_openai_params

import mellea.backends.model_ids as model_ids
from mellea.backends import BaseModelSubclass
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
from mellea.backends.tools import (
    add_tools_from_context_actions,
    add_tools_from_model_options,
    convert_tools_to_json,
)
from mellea.backends.types import ModelOption
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
    CBlock,
    Component,
    Context,
    GenerateLog,
    ModelOutputThunk,
    ModelToolCall,
)
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement

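# A minimal construction sketch (illustrative only; the non-default model string below is an
# assumption, not something this module prescribes):
#
#     backend = LiteLLMBackend()  # defaults to IBM Granite 3.3 8B served by a local Ollama
#     backend = LiteLLMBackend(
#         model_id="openai/gpt-4o",  # any LiteLLM model identifier; credentials are read from env vars
#         model_options={ModelOption.MAX_NEW_TOKENS: 256},
#     )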
class LiteLLMBackend(FormatterBackend):
    """A generic LiteLLM compatible backend."""

    def __init__(
        self,
        model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_3_3_8B.ollama_name),
        formatter: Formatter | None = None,
        base_url: str | None = "http://localhost:11434",
        model_options: dict | None = None,
    ):
| 44 | + """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs. |
| 45 | +
|
| 46 | + Args: |
| 47 | + model_id : The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables. |
| 48 | + formatter: A custom formatter based on backend.If None, defaults to TemplateFormatter |
| 49 | + base_url : Base url for LLM API. Defaults to None. |
| 50 | + model_options : Generation options to pass to the LLM. Defaults to None. |
| 51 | + """ |
        super().__init__(
            model_id=model_id,
            formatter=(
                formatter
                if formatter is not None
                else TemplateFormatter(model_id=model_id)
            ),
            model_options=model_options,
        )

        assert isinstance(model_id, str), "Model ID must be a string."
        self._model_id = model_id

        if base_url is None:
            self._base_url = "http://localhost:11434/v1"  # ollama
        else:
            self._base_url = base_url

        # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
        # These are usually values that must be extracted beforehand or that are common among backend providers.
        # OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
        # users should only specify a single one in their request.
        self.to_mellea_model_opts_map = {
            "system": ModelOption.SYSTEM_PROMPT,
            "reasoning_effort": ModelOption.THINKING,  # TODO: JAL; see which of these are actually extracted...
            "seed": ModelOption.SEED,
            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
            "max_tokens": ModelOption.MAX_NEW_TOKENS,
            "tools": ModelOption.TOOLS,
            "functions": ModelOption.TOOLS,
        }

        # A mapping of Mellea specific ModelOptions to the specific names for this backend.
        # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
        # Usually, values that are intentionally extracted while prepping for the backend generate call
        # will be omitted here so that they will be removed when model_options are processed
        # for the call to the model.
        self.from_mellea_model_opts_map = {
            ModelOption.SEED: "seed",
            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
        }

    def generate_from_context(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ):
| 104 | + """See `generate_from_chat_context`.""" |
| 105 | + assert ctx.is_chat_context, NotImplementedError( |
| 106 | + "The Openai backend only supports chat-like contexts." |
| 107 | + ) |
| 108 | + return self._generate_from_chat_context_standard( |
| 109 | + action, |
| 110 | + ctx, |
| 111 | + format=format, |
| 112 | + model_options=model_options, |
| 113 | + generate_logs=generate_logs, |
| 114 | + tool_calls=tool_calls, |
| 115 | + ) |
| 116 | + |
    def _simplify_and_merge(
        self, model_options: dict[str, Any] | None
    ) -> dict[str, Any]:
| 120 | + """Simplifies model_options to use the Mellea specific ModelOption.Option and merges the backend's model_options with those passed into this call. |
| 121 | +
|
| 122 | + Rules: |
| 123 | + - Within a model_options dict, existing keys take precedence. This means remapping to mellea specific keys will maintain the value of the mellea specific key if one already exists. |
| 124 | + - When merging, the keys/values from the dictionary passed into this function take precedence. |
| 125 | +
|
| 126 | + Because this function simplifies and then merges, non-Mellea keys from the passed in model_options will replace |
| 127 | + Mellea specific keys from the backend's model_options. |
| 128 | +
|
| 129 | + Args: |
| 130 | + model_options: the model_options for this call |
| 131 | +
|
| 132 | + Returns: |
| 133 | + a new dict |
| 134 | + """ |
        backend_model_opts = ModelOption.replace_keys(
            self.model_options, self.to_mellea_model_opts_map
        )

        if model_options is None:
            return backend_model_opts

        generate_call_model_opts = ModelOption.replace_keys(
            model_options, self.to_mellea_model_opts_map
        )
        return ModelOption.merge_model_options(
            backend_model_opts, generate_call_model_opts
        )

    def _make_backend_specific_and_remove(
        self, model_options: dict[str, Any]
    ) -> dict[str, Any]:
| 152 | + """Maps specified Mellea specific keys to their backend specific version and removes any remaining Mellea keys. |
| 153 | +
|
| 154 | + Additionally, logs any params unknown to litellm and any params that are openai specific but not supported by this model/provider. |
| 155 | +
|
| 156 | + Args: |
| 157 | + model_options: the model_options for this call |
| 158 | +
|
| 159 | + Returns: |
| 160 | + a new dict |
| 161 | + """ |
        backend_specific = ModelOption.replace_keys(
            model_options, self.from_mellea_model_opts_map
        )
        backend_specific = ModelOption.remove_special_keys(backend_specific)

        # We set `drop_params=True` which will drop non-supported openai params; check for non-openai
        # params that might cause errors and log which openai params aren't supported here.
        # See https://docs.litellm.ai/docs/completion/input.
        # standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
        supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
            self._model_id
        )
        supported_params = (
            set(supported_params_list) if supported_params_list is not None else set()
        )

        # unknown_keys = []  # keys that are unknown to litellm
        unsupported_openai_params = []  # openai params that are known to litellm but not supported for this model/provider
        for key in backend_specific.keys():
            if key not in supported_params:
                unsupported_openai_params.append(key)

        # if len(unknown_keys) > 0:
        #     FancyLogger.get_logger().warning(
        #         f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
        #     )

        if len(unsupported_openai_params) > 0:
            FancyLogger.get_logger().warning(
                f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
            )
            for key in unsupported_openai_params:
                del backend_specific[key]

        return backend_specific

    def _generate_from_chat_context_standard(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass]
        | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> ModelOutputThunk:
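        """Generate a chat completion via `litellm.completion` from the linearized chat context plus the current action."""
        # Merge the backend's default model_options with the per-call options, normalized to Mellea ModelOption keys.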
        model_opts = self._simplify_and_merge(model_options)
        linearized_context = ctx.render_for_generation()
        assert linearized_context is not None, (
            "Cannot generate from a non-linear context in a FormatterBackend."
        )
        # Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this.
        messages: list[Message] = self.formatter.to_chat_messages(linearized_context)
        # Add the final message.
        match action:
            case ALoraRequirement():
                raise Exception("The LiteLLM backend does not support activated LoRAs.")
            case _:
                messages.extend(self.formatter.to_chat_messages([action]))

        conversation: list[dict] = []
        system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "")
        if system_prompt != "":
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend([{"role": m.role, "content": m.content} for m in messages])

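        # When a pydantic schema is supplied via `format`, request structured output using an
        # OpenAI-style json_schema response format; otherwise fall back to plain text.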
        if format is not None:
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": format.__name__,
                    "schema": format.model_json_schema(),
                    "strict": True,
                },
            }
        else:
            response_format = {"type": "text"}

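        # Mellea's THINKING option is forwarded to litellm as `reasoning_effort` below.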
        thinking = model_opts.get(ModelOption.THINKING, None)
        if type(thinking) is bool and thinking:
            # OpenAI uses strings for its reasoning levels.
            thinking = "medium"

        # Append tool call information if applicable.
        tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
        formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

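        # Convert remaining Mellea option keys to their backend-specific names and drop any
        # Mellea-only keys before handing the options to litellm.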
        model_specific_options = self._make_backend_specific_and_remove(model_opts)

        chat_response: litellm.ModelResponse = litellm.completion(
            model=self._model_id,
            messages=conversation,
            tools=formatted_tools,
            response_format=response_format,
            reasoning_effort=thinking,  # type: ignore
            drop_params=True,  # See note in `_make_backend_specific_and_remove`.
            **model_specific_options,
        )

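        # Only the non-streaming response shape (`litellm.utils.Choices`) is handled here;
        # streaming responses would use a different choice type.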
        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        result = ModelOutputThunk(
            value=choice_0.message.content,
            meta={
                "litellm_chat_response": chat_response.choices[0].model_dump()
            },  # NOTE: Using model dump here to comply with `TemplateFormatter`
            tool_calls=self._extract_model_tool_requests(tools, chat_response),
        )

        parsed_result = self.formatter.parse(source_component=action, result=result)

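        # When the caller provides a GenerateLog list, record the full request/response for observability.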
        if generate_logs is not None:
            assert isinstance(generate_logs, list)
            generate_log = GenerateLog()
            generate_log.prompt = conversation
            generate_log.backend = f"litellm::{self.model_id!s}"
            generate_log.model_options = model_specific_options
            generate_log.date = datetime.datetime.now()
            generate_log.model_output = chat_response
            generate_log.extra = {
                "format": format,
                "tools_available": tools,
                "tools_called": result.tool_calls,
                "seed": model_opts.get(ModelOption.SEED, None),
            }
            generate_log.action = action
            generate_log.result = parsed_result
            generate_logs.append(generate_log)

        return parsed_result

    @staticmethod
    def _extract_tools(
        action, format, model_opts, tool_calls, ctx
    ) -> dict[str, Callable]:
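        """Collect the tools offered for this generation from model options, context actions, and the action itself; specifying `format` disables tool calling."""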
        tools: dict[str, Callable] = dict()
        if tool_calls:
            if format:
                FancyLogger.get_logger().warning(
                    f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                )
            else:
                add_tools_from_model_options(tools, model_opts)
                add_tools_from_context_actions(tools, ctx.actions_for_available_tools())

                # Add the tools from the action for this generation last so that
                # they overwrite conflicting names.
                add_tools_from_context_actions(tools, [action])
            FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
        return tools

    def _generate_from_raw(
        self,
        actions: list[Component | CBlock],
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions api. Gives the input provided to the model without templating."""
        raise NotImplementedError("This method is not implemented yet.")

    def _extract_model_tool_requests(
        self, tools: dict[str, Callable], chat_response: litellm.ModelResponse
    ) -> dict[str, ModelToolCall] | None:
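        """Map tool calls in the model response back to the callables that were offered, parsing their JSON arguments."""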
        model_tool_calls: dict[str, ModelToolCall] = {}
        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        calls = choice_0.message.tool_calls
        if calls:
            for tool_call in calls:
                tool_name = str(tool_call.function.name)
                tool_args = tool_call.function.arguments

                func = tools.get(tool_name)
                if func is None:
                    FancyLogger.get_logger().warning(
                        f"model attempted to call a non-existing function: {tool_name}"
                    )
                    continue  # skip this function if we can't find it.

                # Returns the args as a string. Parse it here.
                args = json.loads(tool_args)
                model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

        if len(model_tool_calls) > 0:
            return model_tool_calls
        return None