Commit 9310eaa

Started light LLM backend

1 parent a0a945b commit 9310eaa

File tree

4 files changed: +463, -4 lines changed


mellea/backends/litellm.py

Lines changed: 335 additions & 0 deletions
@@ -0,0 +1,335 @@
"""A generic LiteLLM backend that wraps around the litellm python sdk."""

import datetime
import json
from collections.abc import Callable

import litellm

import mellea.backends.model_ids as model_ids
from mellea.backends import BaseModelSubclass
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
from mellea.backends.tools import convert_tools_to_json, get_tools_from_action
from mellea.backends.types import ModelOption
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
    CBlock,
    Component,
    Context,
    GenerateLog,
    ModelOutputThunk,
    ModelToolCall,
    TemplateRepresentation,
)
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement


class LiteLLMBackend(FormatterBackend):
    """A generic LiteLLM compatible backend."""

    def __init__(
        self,
        model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_3_3_8B.ollama_name),
        formatter: Formatter | None = None,
        base_url: str | None = "http://localhost:11434",
        model_options: dict | None = None,
    ):
        """Initialize a LiteLLM-compatible backend.

        Args:
            model_id: The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
            formatter: A custom formatter for the backend. If None, defaults to TemplateFormatter.
            base_url: Base url for the LLM API. Defaults to the local Ollama endpoint.
            model_options: Generation options to pass to the LLM. Defaults to None.
        """
        super().__init__(
            model_id=model_id,
            formatter=(
                formatter
                if formatter is not None
                else TemplateFormatter(model_id=model_id)
            ),
            model_options=model_options,
        )

        assert isinstance(model_id, str), "Model ID must be a string."
        self._model_id = model_id

        if base_url is None:
            self._base_url = "http://localhost:11434/v1"  # ollama
        else:
            self._base_url = base_url

    def generate_from_context(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ):
        """See `generate_from_chat_context`."""
        assert ctx.is_chat_context, NotImplementedError(
            "The LiteLLM backend only supports chat-like contexts."
        )
        return self._generate_from_chat_context_standard(
            action,
            ctx,
            format=format,
            model_options=model_options,
            generate_logs=generate_logs,
            tool_calls=tool_calls,
        )

    def _simplify_and_merge(self, mo: dict | None) -> dict:
        mo_safe = {} if mo is None else mo.copy()
        mo_merged = ModelOption.merge_model_options(self.model_options, mo_safe)

        # Map Mellea model-option keys to valid litellm parameter names.
        mo_mapping = {
            ModelOption.TOOLS: "tools",
            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
            ModelOption.SEED: "seed",
            ModelOption.THINKING: "thinking",
        }
        mo_res = ModelOption.replace_keys(mo_merged, mo_mapping)
        mo_res = ModelOption.remove_special_keys(mo_res)

        # Drop anything litellm does not report as supported for this model.
        supported_params = litellm.get_supported_openai_params(self._model_id)
        assert supported_params is not None
        for k in list(mo_res.keys()):
            if k not in supported_params:
                del mo_res[k]
                FancyLogger.get_logger().warning(
                    f"Skipping '{k}' -- Model-Option not supported by {self.model_id}."
                )

        return mo_res

    def _generate_from_chat_context_standard(
        self,
        action: Component | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass]
        | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> ModelOutputThunk:
        model_options = {} if model_options is None else model_options
        model_opts = self._simplify_and_merge(model_options)
        linearized_context = ctx.linearize()
        assert linearized_context is not None, (
            "Cannot generate from a non-linear context in a FormatterBackend."
        )
        # Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this.
        messages: list[Message] = self.formatter.to_chat_messages(linearized_context)
        # Add the final message.
        match action:
            case ALoraRequirement():
                raise Exception("The LiteLLM backend does not support activated LoRAs.")
            case _:
                messages.extend(self.formatter.to_chat_messages([action]))

        conversation: list[dict] = []
        system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, "")
        if system_prompt != "":
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend([{"role": m.role, "content": m.content} for m in messages])

        if format is not None:
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": format.__name__,
                    "schema": format.model_json_schema(),
                    "strict": True,
                },
            }
        else:
            response_format = {"type": "text"}

        # Append tool call information if applicable.
        tools = self._extract_tools(action, format, model_opts, tool_calls)
        formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

        chat_response: litellm.ModelResponse = litellm.completion(
            model=self._model_id,
            messages=conversation,
            tools=formatted_tools,
            response_format=response_format,
            **model_opts,
        )

        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        result = ModelOutputThunk(
            value=choice_0.message.content,
            meta={
                "litellm_chat_response": chat_response.choices[0].model_dump()
            },  # NOTE: Using model dump here to comply with `TemplateFormatter`
            tool_calls=self._extract_model_tool_requests(tools, chat_response),
        )

        parsed_result = self.formatter.parse(source_component=action, result=result)

        if generate_logs is not None:
            assert isinstance(generate_logs, list)
            generate_log = GenerateLog()
            generate_log.prompt = conversation
            generate_log.backend = f"litellm::{self.model_id!s}"
            generate_log.model_options = model_opts
            generate_log.date = datetime.datetime.now()
            generate_log.model_output = chat_response
            generate_log.extra = {
                "format": format,
                "tools_available": tools,
                "tools_called": result.tool_calls,
                "seed": model_opts.get("seed", None),
            }
            generate_log.action = action
            generate_log.result = parsed_result
            generate_logs.append(generate_log)

        return parsed_result

    @staticmethod
    def _extract_tools(action, format, model_opts, tool_calls):
        tools: dict[str, Callable] = dict()
        if tool_calls:
            if format:
                FancyLogger.get_logger().warning(
                    f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                )
            else:
                if isinstance(action, Component) and isinstance(
                    action.format_for_llm(), TemplateRepresentation
                ):
                    tools = get_tools_from_action(action)

        model_options_tools = model_opts.get(ModelOption.TOOLS, None)
        if model_options_tools is not None:
            assert isinstance(model_options_tools, dict)
            for fn_name in model_options_tools:
                # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
                assert fn_name not in tools.keys(), (
                    f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
                )
                # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
                assert type(fn_name) is str, (
                    "When providing a `ModelOption.TOOLS` parameter to `model_options`, always use the type Dict[str, Callable] where `str` is the function name and the callable is the function."
                )
                assert callable(model_options_tools[fn_name]), (
                    "When providing a `ModelOption.TOOLS` parameter to `model_options`, always use the type Dict[str, Callable] where `str` is the function name and the callable is the function."
                )
                # Add the model_options tool to the existing set of tools.
                tools[fn_name] = model_options_tools[fn_name]
        return tools

    def _generate_from_raw(
        self,
        actions: list[Component | CBlock],
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions api. Gives the input provided to the model without templating."""
        raise NotImplementedError("This method is not implemented yet.")
        # extra_body = {}
        # if format is not None:
        #     FancyLogger.get_logger().warning(
        #         "The official OpenAI completion api does not accept response format / structured decoding; "
        #         "it will be passed as an extra arg."
        #     )
        #
        #     # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
        #     extra_body["guided_json"] = format.model_json_schema()
        #
        # model_opts = self._simplify_and_merge(model_options, is_chat_context=False)
        #
        # prompts = [self.formatter.print(action) for action in actions]
        #
        # try:
        #     completion_response: Completion = self._client.completions.create(
        #         model=self._hf_model_id,
        #         prompt=prompts,
        #         extra_body=extra_body,
        #         **self._make_backend_specific_and_remove(
        #             model_opts, is_chat_context=False
        #         ),
        #     )  # type: ignore
        # except openai.BadRequestError as e:
        #     if openai_ollama_batching_error in e.message:
        #         FancyLogger.get_logger().error(
        #             "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, "
        #             "your requests will fail since ollama doesn't support batching requests."
        #         )
        #     raise e
        #
        # # Necessary for type checker.
        # assert isinstance(completion_response, Completion)
        #
        # results = [
        #     ModelOutputThunk(
        #         value=response.text,
        #         meta={"oai_completion_response": response.model_dump()},
        #     )
        #     for response in completion_response.choices
        # ]
        #
        # for i, result in enumerate(results):
        #     self.formatter.parse(actions[i], result)
        #
        # if generate_logs is not None:
        #     assert isinstance(generate_logs, list)
        #     date = datetime.datetime.now()
        #
        #     for i in range(len(prompts)):
        #         generate_log = GenerateLog()
        #         generate_log.prompt = prompts[i]
        #         generate_log.backend = f"openai::{self.model_id!s}"
        #         generate_log.model_options = model_opts
        #         generate_log.date = date
        #         generate_log.model_output = completion_response
        #         generate_log.extra = {"seed": model_opts.get("seed", None)}
        #         generate_log.action = actions[i]
        #         generate_log.result = results[i]
        #         generate_logs.append(generate_log)
        #
        # return results

    def _extract_model_tool_requests(
        self, tools: dict[str, Callable], chat_response: litellm.ModelResponse
    ) -> dict[str, ModelToolCall] | None:
        model_tool_calls: dict[str, ModelToolCall] = {}
        choice_0 = chat_response.choices[0]
        assert isinstance(choice_0, litellm.utils.Choices), (
            "Only works for non-streaming response for now"
        )
        calls = choice_0.message.tool_calls
        if calls:
            for tool_call in calls:
                tool_name = str(tool_call.function.name)
                tool_args = tool_call.function.arguments

                func = tools.get(tool_name)
                if func is None:
                    FancyLogger.get_logger().warning(
                        f"model attempted to call a non-existing function: {tool_name}"
                    )
                    continue  # skip this function if we can't find it.

                # Returns the args as a string. Parse it here.
                args = json.loads(tool_args)
                model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

        if len(model_tool_calls) > 0:
            return model_tool_calls
        return None
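
A minimal usage sketch (not part of this commit), assuming a local Ollama server is serving the default Granite model and using the session API the same way the tests added below do:

from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends.litellm import LiteLLMBackend

# Defaults to the "ollama/..." Granite model id pointed at http://localhost:11434.
m = MelleaSession(LiteLLMBackend())

# A simple chat turn routed through litellm.completion().
print(m.chat("hello world"))

# Options are translated by _simplify_and_merge (e.g. MAX_NEW_TOKENS -> max_completion_tokens);
# keys litellm does not support for the model are dropped with a warning.
res = m.instruct(
    "Write a one-line greeting.",
    model_options={ModelOption.SEED: 123, ModelOption.MAX_NEW_TOKENS: 100},
)
print(res.value)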

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ dependencies = [
     "mistletoe>=1.4.0",
     "trl==0.19.0",
     "peft",
-    "torch"
+    "torch",
+    "litellm>=1.75.5.post1",
 ]

 [project.scripts]
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import mellea
from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends.litellm import LiteLLMBackend
from mellea.stdlib.chat import Message
from mellea.stdlib.sampling import RejectionSamplingStrategy


class TestLitellmOllama:
    m = MelleaSession(LiteLLMBackend())

    def test_litellm_ollama_chat(self):
        res = self.m.chat("hello world")
        assert res is not None
        assert isinstance(res, Message)

    def test_litellm_ollama_instruct(self):
        res = self.m.instruct(
            "Write an email to the interns.",
            requirements=["be funny"],
            strategy=RejectionSamplingStrategy(loop_budget=3),
        )
        assert res is not None
        assert isinstance(res.value, str)

    def test_litellm_ollama_instruct_options(self):
        res = self.m.instruct(
            "Write an email to the interns.",
            requirements=["be funny"],
            model_options={
                ModelOption.SEED: 123,
                ModelOption.TEMPERATURE: 0.5,
                ModelOption.THINKING: True,
                ModelOption.MAX_NEW_TOKENS: 100,
                "stream": False,
                "homer_simpson": "option should be kicked out",
            },
        )
        assert res is not None
        assert isinstance(res.value, str)
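
For reference, a small sketch (not part of this commit) of how the option translation in `LiteLLMBackend._simplify_and_merge` is expected to treat the options used in the last test above. It only needs litellm and mellea installed locally, since no request is sent; calling the private method here is purely illustrative:

from mellea.backends import ModelOption
from mellea.backends.litellm import LiteLLMBackend

backend = LiteLLMBackend(model_options={})  # default ollama/granite model id
opts = backend._simplify_and_merge(
    {
        ModelOption.SEED: 123,
        ModelOption.MAX_NEW_TOKENS: 100,
        "homer_simpson": "option should be kicked out",
    }
)
# SEED should map to "seed" and MAX_NEW_TOKENS to "max_completion_tokens"
# (each kept only if litellm reports it as supported for the model), while
# "homer_simpson" should be dropped with a warning.
print(opts)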

0 commit comments
