
Commit 61d7f0e

feat: LiteLLM backend (#60)
* Started light LLM backend
* change model options handling for litellm
* change model_opts usage
* remove a duplicate collection of tools
* litellm as optional dependency
* using new utility functions fixing model option cleanup
* make litellm tests "qualitative"
* fix tool extraction function according to #60 (comment)
* fixing test format w.r.t. #60 (comment)
* typo
* fix merge

---------

Co-authored-by: Jake LoRocco <[email protected]>
1 parent c8837c6 commit 61d7f0e
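
For orientation, here is a minimal usage sketch of the backend this commit introduces. The constructor arguments and ModelOption keys below are taken from the new mellea/backends/litellm.py; the concrete model string is illustrative, and this is a hedged sketch rather than an official example from the repository.

# Hedged usage sketch (not part of this commit). Constructor arguments and
# ModelOption keys mirror the new LiteLLMBackend; the model string is illustrative.
from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.types import ModelOption

backend = LiteLLMBackend(
    model_id="ollama/granite3.3:8b",      # any LiteLLM model id; provider credentials come from env vars
    base_url="http://localhost:11434",    # the default points at a local Ollama server
    model_options={
        ModelOption.MAX_NEW_TOKENS: 512,  # remapped to `max_completion_tokens` before the litellm call
        ModelOption.SEED: 42,
    },
)
# generate_from_context(action, ctx, ...) then drives litellm.completion() with the
# merged options, an optional JSON-schema response_format, and any extracted tools.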

File tree

4 files changed, +531 -12 lines changed


mellea/backends/litellm.py

Lines changed: 354 additions & 0 deletions
@@ -0,0 +1,354 @@
"""A generic LiteLLM compatible backend that wraps around the LiteLLM python sdk."""

import datetime
import json
from collections.abc import Callable
from typing import Any

import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.get_supported_openai_params

import mellea.backends.model_ids as model_ids
from mellea.backends import BaseModelSubclass
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
from mellea.backends.tools import (
    add_tools_from_context_actions,
    add_tools_from_model_options,
    convert_tools_to_json,
)
from mellea.backends.types import ModelOption
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
    CBlock,
    Component,
    Context,
    GenerateLog,
    ModelOutputThunk,
    ModelToolCall,
)
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement


class LiteLLMBackend(FormatterBackend):
    """A generic LiteLLM compatible backend."""

    def __init__(
        self,
        model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_3_3_8B.ollama_name),
        formatter: Formatter | None = None,
        base_url: str | None = "http://localhost:11434",
        model_options: dict | None = None,
    ):
        """Initialize an OpenAI-compatible backend via LiteLLM. For any additional kwargs that you need to pass to the client, pass them as part of **kwargs.

        Args:
            model_id: The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
            formatter: A custom formatter based on backend. If None, defaults to TemplateFormatter.
            base_url: Base url for the LLM API. Defaults to the local Ollama endpoint.
            model_options: Generation options to pass to the LLM. Defaults to None.
        """
        super().__init__(
            model_id=model_id,
            formatter=(
                formatter
                if formatter is not None
                else TemplateFormatter(model_id=model_id)
            ),
            model_options=model_options,
        )

        assert isinstance(model_id, str), "Model ID must be a string."
        self._model_id = model_id

        if base_url is None:
            self._base_url = "http://localhost:11434/v1"  # ollama
        else:
            self._base_url = base_url

        # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
        # These are usually values that must be extracted beforehand or that are common among backend providers.
        # OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
        # users should only be specifying a single one in their request.
        self.to_mellea_model_opts_map = {
            "system": ModelOption.SYSTEM_PROMPT,
            "reasoning_effort": ModelOption.THINKING,  # TODO: JAL; see which of these are actually extracted...
            "seed": ModelOption.SEED,
            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
            "max_tokens": ModelOption.MAX_NEW_TOKENS,
            "tools": ModelOption.TOOLS,
            "functions": ModelOption.TOOLS,
        }

        # A mapping of Mellea specific ModelOptions to the specific names for this backend.
        # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
        # Usually, values that are intentionally extracted while prepping for the backend generate call
        # will be omitted here so that they will be removed when model_options are processed
        # for the call to the model.
        self.from_mellea_model_opts_map = {
            ModelOption.SEED: "seed",
            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
        }

    def generate_from_context(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ):
        """See `generate_from_chat_context`."""
        assert ctx.is_chat_context, NotImplementedError(
            "The LiteLLM backend only supports chat-like contexts."
        )
        return self._generate_from_chat_context_standard(
            action,
            ctx,
            format=format,
            model_options=model_options,
            generate_logs=generate_logs,
            tool_calls=tool_calls,
        )

    def _simplify_and_merge(
        self, model_options: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Simplifies model_options to use the Mellea specific ModelOption.Option and merges the backend's model_options with those passed into this call.

        Rules:
        - Within a model_options dict, existing keys take precedence. This means remapping to mellea specific keys will maintain the value of the mellea specific key if one already exists.
        - When merging, the keys/values from the dictionary passed into this function take precedence.

        Because this function simplifies and then merges, non-Mellea keys from the passed-in model_options will replace
        Mellea specific keys from the backend's model_options.

        Args:
            model_options: the model_options for this call

        Returns:
            a new dict
        """
        backend_model_opts = ModelOption.replace_keys(
            self.model_options, self.to_mellea_model_opts_map
        )

        if model_options is None:
            return backend_model_opts

        generate_call_model_opts = ModelOption.replace_keys(
            model_options, self.to_mellea_model_opts_map
        )
        return ModelOption.merge_model_options(
            backend_model_opts, generate_call_model_opts
        )

    def _make_backend_specific_and_remove(
        self, model_options: dict[str, Any]
    ) -> dict[str, Any]:
        """Maps specified Mellea specific keys to their backend specific version and removes any remaining Mellea keys.

        Additionally, logs any params unknown to litellm and any params that are openai specific but not supported by this model/provider.

        Args:
            model_options: the model_options for this call

        Returns:
            a new dict
        """
        backend_specific = ModelOption.replace_keys(
            model_options, self.from_mellea_model_opts_map
        )
        backend_specific = ModelOption.remove_special_keys(backend_specific)

        # We set `drop_params=True` which will drop non-supported openai params; check for non-openai
        # params that might cause errors and log which openai params aren't supported here.
        # See https://docs.litellm.ai/docs/completion/input.
        # standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
        supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
            self._model_id
        )
        supported_params = (
            set(supported_params_list) if supported_params_list is not None else set()
        )

        # unknown_keys = []  # keys that are unknown to litellm
        unsupported_openai_params = []  # openai params that are known to litellm but not supported for this model/provider
        for key in backend_specific.keys():
            if key not in supported_params:
                unsupported_openai_params.append(key)

        # if len(unknown_keys) > 0:
        #     FancyLogger.get_logger().warning(
        #         f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
        #     )

        if len(unsupported_openai_params) > 0:
            FancyLogger.get_logger().warning(
                f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
            )
            for key in unsupported_openai_params:
                del backend_specific[key]

        return backend_specific

    def _generate_from_chat_context_standard(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass]
        | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> ModelOutputThunk:
        model_opts = self._simplify_and_merge(model_options)
        linearized_context = ctx.render_for_generation()
        assert linearized_context is not None, (
            "Cannot generate from a non-linear context in a FormatterBackend."
        )
        # Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this.
        messages: list[Message] = self.formatter.to_chat_messages(linearized_context)
        # Add the final message.
        match action:
            case ALoraRequirement():
                raise Exception("The LiteLLM backend does not support activated LoRAs.")
            case _:
                messages.extend(self.formatter.to_chat_messages([action]))

        conversation: list[dict] = []
        system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "")
        if system_prompt != "":
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend([{"role": m.role, "content": m.content} for m in messages])

        if format is not None:
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": format.__name__,
                    "schema": format.model_json_schema(),
                    "strict": True,
                },
            }
        else:
            response_format = {"type": "text"}

        thinking = model_opts.get(ModelOption.THINKING, None)
        if type(thinking) is bool and thinking:
            # OpenAI uses strings for its reasoning levels.
            thinking = "medium"

        # Append tool call information if applicable.
        tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
        formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

        model_specific_options = self._make_backend_specific_and_remove(model_opts)

        chat_response: litellm.ModelResponse = litellm.completion(
            model=self._model_id,
            messages=conversation,
            tools=formatted_tools,
            response_format=response_format,
            reasoning_effort=thinking,  # type: ignore
            drop_params=True,  # See note in `_make_backend_specific_and_remove`.
            **model_specific_options,
        )

        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        result = ModelOutputThunk(
            value=choice_0.message.content,
            meta={
                "litellm_chat_response": chat_response.choices[0].model_dump()
            },  # NOTE: Using model dump here to comply with `TemplateFormatter`
            tool_calls=self._extract_model_tool_requests(tools, chat_response),
        )

        parsed_result = self.formatter.parse(source_component=action, result=result)

        if generate_logs is not None:
            assert isinstance(generate_logs, list)
            generate_log = GenerateLog()
            generate_log.prompt = conversation
            generate_log.backend = f"litellm::{self.model_id!s}"
            generate_log.model_options = model_specific_options
            generate_log.date = datetime.datetime.now()
            generate_log.model_output = chat_response
            generate_log.extra = {
                "format": format,
                "tools_available": tools,
                "tools_called": result.tool_calls,
                "seed": model_opts.get("seed", None),
            }
            generate_log.action = action
            generate_log.result = parsed_result
            generate_logs.append(generate_log)

        return parsed_result

    @staticmethod
    def _extract_tools(
        action, format, model_opts, tool_calls, ctx
    ) -> dict[str, Callable]:
        tools: dict[str, Callable] = dict()
        if tool_calls:
            if format:
                FancyLogger.get_logger().warning(
                    f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                )
            else:
                add_tools_from_model_options(tools, model_opts)
                add_tools_from_context_actions(tools, ctx.actions_for_available_tools())

                # Add the tools from the action for this generation last so that
                # they overwrite conflicting names.
                add_tools_from_context_actions(tools, [action])
            FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
        return tools

    def _generate_from_raw(
        self,
        actions: list[Component | CBlock],
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions api. Gives the input provided to the model without templating."""
        raise NotImplementedError("This method is not implemented yet.")

    def _extract_model_tool_requests(
        self, tools: dict[str, Callable], chat_response: litellm.ModelResponse
    ) -> dict[str, ModelToolCall] | None:
        model_tool_calls: dict[str, ModelToolCall] = {}
        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        calls = choice_0.message.tool_calls
        if calls:
            for tool_call in calls:
                tool_name = str(tool_call.function.name)
                tool_args = tool_call.function.arguments

                func = tools.get(tool_name)
                if func is None:
                    FancyLogger.get_logger().warning(
                        f"model attempted to call a non-existing function: {tool_name}"
                    )
                    continue  # skip this function if we can't find it.

                # Returns the args as a string. Parse it here.
                args = json.loads(tool_args)
                model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

        if len(model_tool_calls) > 0:
            return model_tool_calls
        return None
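
To make the option handling above concrete, the following hedged sketch traces a caller-supplied options dict through _simplify_and_merge and _make_backend_specific_and_remove. The expected values in the comments are inferred from the docstrings and mappings shown above, not captured output, and unsupported keys may additionally be dropped depending on what litellm reports for the chosen model.

# Hedged walk-through of the option remapping described above; the "expected"
# dicts are inferred from the mappings, not verified outputs.
from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.types import ModelOption

backend = LiteLLMBackend(
    model_options={"max_tokens": 256, "system": "You are terse."}
)

merged = backend._simplify_and_merge({"seed": 7})
# Expected: {ModelOption.MAX_NEW_TOKENS: 256,
#            ModelOption.SYSTEM_PROMPT: "You are terse.",
#            ModelOption.SEED: 7}

call_opts = backend._make_backend_specific_and_remove(merged)
# Expected: {"max_completion_tokens": 256, "seed": 7}
# ModelOption.SYSTEM_PROMPT has no backend-specific mapping, so it is stripped here;
# the system prompt is instead injected as a chat message in
# _generate_from_chat_context_standard.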

pyproject.toml

Lines changed: 6 additions & 2 deletions
@@ -62,14 +62,18 @@ hf = [
     "trl>=0.19.0",
 ]
 
+litellm = [
+    "litellm>=1.76"
+]
+
 watsonx = [
     "ibm-watsonx-ai>=1.3.31",
 ]
 docling = [
     "docling>=2.45.0",
 ]
 
-all = ["mellea[watsonx,docling,hf]"]
+all = ["mellea[watsonx,docling,hf,litellm]"]
 
 [dependency-groups]
 # Use these like:
@@ -140,7 +144,7 @@ ignore = [
     # "UP006", # List vs list, etc
     # "UP007", # Option and Union
     # "UP035", # `typing.Set` is deprecated, use `set` instead"
-    "PD901", # Avoid using the generic variable name `df` for DataFrames
+    "PD901", # Avoid using the generic variable name `df` for DataFrames
 ]
 
 [tool.ruff.lint.pydocstyle]
