| 1 | +"""A generic OpenAI compatible backend that wraps around the openai python sdk.""" |

import datetime
import json
from collections.abc import Callable

import litellm

import mellea.backends.model_ids as model_ids
from mellea.backends import BaseModelSubclass
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
from mellea.backends.tools import convert_tools_to_json, get_tools_from_action
from mellea.backends.types import ModelOption
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
    CBlock,
    Component,
    Context,
    GenerateLog,
    ModelOutputThunk,
    ModelToolCall,
    TemplateRepresentation,
)
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement


class LiteLLMBackend(FormatterBackend):
    """A generic LiteLLM compatible backend."""

    def __init__(
        self,
        model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_3_3_8B.ollama_name),
        formatter: Formatter | None = None,
        base_url: str | None = "http://localhost:11434",
        model_options: dict | None = None,
    ):
| 38 | + """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs. |
| 39 | +
|
| 40 | + Args: |
| 41 | + model_id : The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables. |
| 42 | + formatter: A custom formatter based on backend.If None, defaults to TemplateFormatter |
| 43 | + base_url : Base url for LLM API. Defaults to None. |
| 44 | + model_options : Generation options to pass to the LLM. Defaults to None. |
| 45 | + """ |
        super().__init__(
            model_id=model_id,
            formatter=(
                formatter
                if formatter is not None
                else TemplateFormatter(model_id=model_id)
            ),
            model_options=model_options,
        )

        assert isinstance(model_id, str), "Model ID must be a string."
        self._model_id = model_id

        if base_url is None:
            self._base_url = "http://localhost:11434/v1" # ollama
        else:
            self._base_url = base_url

    def generate_from_context(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ):
        """See `generate_from_chat_context`."""
        assert ctx.is_chat_context, NotImplementedError(
            "The LiteLLM backend only supports chat-like contexts."
        )
        return self._generate_from_chat_context_standard(
            action,
            ctx,
            format=format,
            model_options=model_options,
            generate_logs=generate_logs,
            tool_calls=tool_calls,
        )

    def _simplify_and_merge(self, mo: dict | None) -> dict:
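        """Merge `mo` with the backend-level model options and map keys to litellm parameter names.

        For example (illustrative values only), {ModelOption.MAX_NEW_TOKENS: 100, ModelOption.SEED: 7}
        becomes {"max_completion_tokens": 100, "seed": 7}. Any resulting key that
        litellm.get_supported_openai_params() does not report for this model is dropped with a warning.
        """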
        mo_safe = {} if mo is None else mo.copy()
        mo_merged = ModelOption.merge_model_options(self.model_options, mo_safe)

        # map to valid litellm names
        mo_mapping = {
            ModelOption.TOOLS: "tools",
            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
            ModelOption.SEED: "seed",
            ModelOption.THINKING: "thinking",
        }
        mo_res = ModelOption.replace_keys(mo_merged, mo_mapping)
        mo_res = ModelOption.remove_special_keys(mo_res)

        supported_params = litellm.get_supported_openai_params(self._model_id)
        assert supported_params is not None
        for k in list(mo_res.keys()):
            if k not in supported_params:
                del mo_res[k]
                FancyLogger.get_logger().warning(
                    f"Skipping '{k}' -- Model-Option not supported by {self.model_id}."
                )

        return mo_res

    def _generate_from_chat_context_standard(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass]
        | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> ModelOutputThunk:
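        """Generate a chat response by formatting the context into chat messages and calling litellm.completion.

        The completion is wrapped in a ModelOutputThunk (including any tool calls the
        model requested) and parsed by the formatter. When `generate_logs` is a list,
        a GenerateLog entry describing the request and result is appended to it.
        """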
        model_options = {} if model_options is None else model_options
        model_opts = self._simplify_and_merge(model_options)
        linearized_context = ctx.linearize()
        assert linearized_context is not None, (
            "Cannot generate from a non-linear context in a FormatterBackend."
        )
        # Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this.
        messages: list[Message] = self.formatter.to_chat_messages(linearized_context)
        # Add the final message.
        match action:
            case ALoraRequirement():
                raise Exception("The LiteLLM backend does not support activated LoRAs.")
            case _:
                messages.extend(self.formatter.to_chat_messages([action]))

        conversation: list[dict] = []
        system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, "")
        if system_prompt != "":
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend([{"role": m.role, "content": m.content} for m in messages])
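        # At this point `conversation` is a list of OpenAI-style chat message dicts,
        # e.g. [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].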

        if format is not None:
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": format.__name__,
                    "schema": format.model_json_schema(),
                    "strict": True,
                },
            }
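            # This is the OpenAI structured-outputs style `response_format` payload;
            # litellm forwards it to providers that support JSON-schema constrained decoding.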
        else:
            response_format = {"type": "text"}

        # Append tool call information if applicable.
        tools = self._extract_tools(action, format, model_opts, tool_calls)
        formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

        chat_response: litellm.ModelResponse = litellm.completion(
            model=self._model_id,
            messages=conversation,
            tools=formatted_tools,
            response_format=response_format,
            **model_opts,
        )
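        # litellm.completion returns an OpenAI-style ModelResponse; streaming is not
        # handled here (see the Choices assertion below).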

        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        result = ModelOutputThunk(
            value=choice_0.message.content,
            meta={
                "litellm_chat_response": chat_response.choices[0].model_dump()
            }, # NOTE: Using model dump here to comply with `TemplateFormatter`
            tool_calls=self._extract_model_tool_requests(tools, chat_response),
        )

        parsed_result = self.formatter.parse(source_component=action, result=result)

        if generate_logs is not None:
            assert isinstance(generate_logs, list)
            generate_log = GenerateLog()
            generate_log.prompt = conversation
            generate_log.backend = f"litellm::{self.model_id!s}"
            generate_log.model_options = model_opts
            generate_log.date = datetime.datetime.now()
            generate_log.model_output = chat_response
            generate_log.extra = {
                "format": format,
                "tools_available": tools,
                "tools_called": result.tool_calls,
                "seed": model_opts.get("seed", None),
            }
            generate_log.action = action
            generate_log.result = parsed_result
            generate_logs.append(generate_log)

        return parsed_result

    @staticmethod
    def _extract_tools(action, format, model_opts, tool_calls) -> dict[str, Callable]:
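        """Collect the tools available for this request.

        Tools can come from two places: the action's TemplateRepresentation (via
        get_tools_from_action, only when `tool_calls` is True and no `format` is given)
        and a user-supplied `ModelOption.TOOLS` mapping of function name to callable in
        `model_options`, e.g. (names illustrative) {ModelOption.TOOLS: {"get_weather": get_weather}}.
        """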
        tools: dict[str, Callable] = dict()
        if tool_calls:
            if format:
                FancyLogger.get_logger().warning(
                    f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                )
            else:
                if isinstance(action, Component) and isinstance(
                    action.format_for_llm(), TemplateRepresentation
                ):
                    tools = get_tools_from_action(action)

        model_options_tools = model_opts.get(ModelOption.TOOLS, None)
        if model_options_tools is not None:
            assert isinstance(model_options_tools, dict)
            for fn_name in model_options_tools:
                # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
                assert fn_name not in tools.keys(), (
                    f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
                )
                # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
                assert type(fn_name) is str, (
                    "When providing a `ModelOption.TOOLS` parameter to `model_options`, always use the type dict[str, Callable] where `str` is the function name and the callable is the function."
                )
                assert callable(model_options_tools[fn_name]), (
                    "When providing a `ModelOption.TOOLS` parameter to `model_options`, always use the type dict[str, Callable] where `str` is the function name and the callable is the function."
                )
                # Add the model_options tool to the existing set of tools.
                tools[fn_name] = model_options_tools[fn_name]
        return tools

    def _generate_from_raw(
        self,
        actions: list[Component | CBlock],
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions API, passing the raw input to the model without templating."""
        raise NotImplementedError("This method is not implemented yet.")
        # extra_body = {}
        # if format is not None:
        #     FancyLogger.get_logger().warning(
        #         "The official OpenAI completion api does not accept response format / structured decoding; "
        #         "it will be passed as an extra arg."
        #     )
        #
        #     # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
        #     extra_body["guided_json"] = format.model_json_schema()
        #
        # model_opts = self._simplify_and_merge(model_options, is_chat_context=False)
        #
        # prompts = [self.formatter.print(action) for action in actions]
        #
        # try:
        #     completion_response: Completion = self._client.completions.create(
        #         model=self._hf_model_id,
        #         prompt=prompts,
        #         extra_body=extra_body,
        #         **self._make_backend_specific_and_remove(
        #             model_opts, is_chat_context=False
        #         ),
        #     ) # type: ignore
        # except openai.BadRequestError as e:
        #     if openai_ollama_batching_error in e.message:
        #         FancyLogger.get_logger().error(
        #             "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, "
        #             "your requests will fail since ollama doesn't support batching requests."
        #         )
        #     raise e
        #
        # # Necessary for type checker.
        # assert isinstance(completion_response, Completion)
        #
        # results = [
        #     ModelOutputThunk(
        #         value=response.text,
        #         meta={"oai_completion_response": response.model_dump()},
        #     )
        #     for response in completion_response.choices
        # ]
        #
        # for i, result in enumerate(results):
        #     self.formatter.parse(actions[i], result)
        #
        # if generate_logs is not None:
        #     assert isinstance(generate_logs, list)
        #     date = datetime.datetime.now()
        #
        #     for i in range(len(prompts)):
        #         generate_log = GenerateLog()
        #         generate_log.prompt = prompts[i]
        #         generate_log.backend = f"openai::{self.model_id!s}"
        #         generate_log.model_options = model_opts
        #         generate_log.date = date
        #         generate_log.model_output = completion_response
        #         generate_log.extra = {"seed": model_opts.get("seed", None)}
        #         generate_log.action = actions[i]
        #         generate_log.result = results[i]
        #         generate_logs.append(generate_log)
        #
        # return results

    def _extract_model_tool_requests(
        self, tools: dict[str, Callable], chat_response: litellm.ModelResponse
    ) -> dict[str, ModelToolCall] | None:
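        """Map tool calls in the litellm response back to the callables in `tools`.

        Returns a dict of tool name -> ModelToolCall with the JSON-encoded arguments
        decoded, or None if the model did not request any known tool.
        """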
        model_tool_calls: dict[str, ModelToolCall] = {}
        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        calls = choice_0.message.tool_calls
        if calls:
            for tool_call in calls:
                tool_name = str(tool_call.function.name)
                tool_args = tool_call.function.arguments

                func = tools.get(tool_name)
                if func is None:
                    FancyLogger.get_logger().warning(
                        f"model attempted to call a non-existing function: {tool_name}"
                    )
                    continue # skip this function if we can't find it.

                # The API returns tool arguments as a JSON string; parse them here.
                args = json.loads(tool_args)
                model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

        if len(model_tool_calls) > 0:
            return model_tool_calls
        return None