
Commit 08f9ebc

using new utility functions
fixing model option cleanup
1 parent 5265ab6 commit 08f9ebc

2 files changed: 47 additions, 110 deletions

mellea/backends/litellm.py

Lines changed: 23 additions & 102 deletions
@@ -12,7 +12,11 @@
 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
-from mellea.backends.tools import convert_tools_to_json, get_tools_from_action
+from mellea.backends.tools import (
+    add_tools_from_context_actions,
+    add_tools_from_model_options,
+    convert_tools_to_json,
+)
 from mellea.backends.types import ModelOption
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
@@ -22,10 +26,9 @@
     GenerateLog,
     ModelOutputThunk,
     ModelToolCall,
-    TemplateRepresentation,
 )
 from mellea.stdlib.chat import Message
-from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
+from mellea.stdlib.requirement import ALoraRequirement


 class LiteLLMBackend(FormatterBackend):
@@ -86,7 +89,6 @@ def __init__(
         self.from_mellea_model_opts_map = {
             ModelOption.SEED: "seed",
             ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
-            ModelOption.THINKING: "reasoning_effort",
         }

     def generate_from_context(
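
Note: with ModelOption.THINKING dropped from `from_mellea_model_opts_map`, reasoning effort is no longer renamed for litellm; callers pass the litellm key `reasoning_effort` directly in `model_options` (the updated test further down does exactly that). The sketch below is illustrative only, not repository code: it shows how a key map like this is typically applied when building litellm kwargs, and `remap_model_opts` plus the placeholder option keys are hypothetical.

# Hypothetical sketch (not repository code): applying a Mellea -> litellm key map.
from typing import Any

def remap_model_opts(model_opts: dict[str, Any], key_map: dict[str, str]) -> dict[str, Any]:
    """Rename mapped option keys; unmapped keys pass through unchanged."""
    return {key_map.get(key, key): value for key, value in model_opts.items()}

# Placeholder Mellea keys; after this commit "reasoning_effort" is passed through as-is.
key_map = {"<SEED>": "seed", "<MAX_NEW_TOKENS>": "max_completion_tokens"}
opts = {"<SEED>": 123, "<MAX_NEW_TOKENS>": 100, "reasoning_effort": True}
assert remap_model_opts(opts, key_map) == {
    "seed": 123,
    "max_completion_tokens": 100,
    "reasoning_effort": True,
}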
@@ -165,32 +167,31 @@ def _make_backend_specific_and_remove(
         # We set `drop_params=True` which will drop non-supported openai params; check for non-openai
         # params that might cause errors and log which openai params aren't supported here.
         # See https://docs.litellm.ai/docs/completion/input.
-        standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
+        # standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
         supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
             self._model_id
         )
         supported_params = (
             set(supported_params_list) if supported_params_list is not None else set()
         )

-        unknown_keys = []  # keys that are unknown to litellm
+        # unknown_keys = []  # keys that are unknown to litellm
         unsupported_openai_params = []  # openai params that are known to litellm but not supported for this model/provider
         for key in backend_specific.keys():
-            if key not in standard_openai_subset.keys():
-                unknown_keys.append(key)
-
-            elif key not in supported_params:
+            if key not in supported_params:
                 unsupported_openai_params.append(key)

-        if len(unknown_keys) > 0:
-            FancyLogger.get_logger().warning(
-                f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
-            )
+        # if len(unknown_keys) > 0:
+        #     FancyLogger.get_logger().warning(
+        #         f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
+        #     )

         if len(unsupported_openai_params) > 0:
             FancyLogger.get_logger().warning(
                 f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
             )
+            for key in unsupported_openai_params:
+                del backend_specific[key]

         return backend_specific
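
Note on the "fixing model option cleanup" part of this commit: unsupported params are now deleted from `backend_specific` after being logged, instead of relying solely on litellm's `drop_params=True`, so they no longer show up downstream (e.g. in `generate_log.model_options`). A stand-alone illustration of the new filtering step follows; the `supported_params` set is invented for the example, and `"homer_simpson"` is the junk key used by the updated test.

# Illustration only: in the backend, supported_params comes from litellm's
# get_supported_openai_params() for the configured model; this set is invented.
supported_params = {"seed", "max_completion_tokens", "stream"}
backend_specific = {"seed": 123, "stream": False, "homer_simpson": "option should be kicked out"}

unsupported_openai_params = [k for k in backend_specific if k not in supported_params]
for key in unsupported_openai_params:
    del backend_specific[key]  # previously only logged; now dropped before the litellm call

assert backend_specific == {"seed": 123, "stream": False}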

@@ -206,7 +207,7 @@ def _generate_from_chat_context_standard(
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
         model_opts = self._simplify_and_merge(model_options)
-        linearized_context = ctx.linearize()
+        linearized_context = ctx.render_for_generation()
         assert linearized_context is not None, (
             "Cannot generate from a non-linear context in a FormatterBackend."
         )
@@ -246,14 +247,16 @@ def _generate_from_chat_context_standard(
         tools = self._extract_tools(action, format, model_opts, tool_calls)
         formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

+        model_specific_options = self._make_backend_specific_and_remove(model_opts)
+
         chat_response: litellm.ModelResponse = litellm.completion(
             model=self._model_id,
             messages=conversation,
             tools=formatted_tools,
             response_format=response_format,
             reasoning_effort=thinking,  # type: ignore
             drop_params=True,  # See note in `_make_backend_specific_and_remove`.
-            **self._make_backend_specific_and_remove(model_opts),
+            **model_specific_options,
         )

         choice_0 = chat_response.choices[0]
@@ -275,7 +278,7 @@ def _generate_from_chat_context_standard(
             generate_log = GenerateLog()
             generate_log.prompt = conversation
             generate_log.backend = f"litellm::{self.model_id!s}"
-            generate_log.model_options = model_opts
+            generate_log.model_options = model_specific_options
             generate_log.date = datetime.datetime.now()
             generate_log.model_output = chat_response
             generate_log.extra = {
@@ -291,36 +294,16 @@
         return parsed_result

     @staticmethod
-    def _extract_tools(action, format, model_opts, tool_calls):
+    def _extract_tools(action, format, model_opts, tool_calls) -> dict[str, Callable]:
         tools: dict[str, Callable] = dict()
         if tool_calls:
             if format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
             else:
-                if isinstance(action, Component) and isinstance(
-                    action.format_for_llm(), TemplateRepresentation
-                ):
-                    tools = get_tools_from_action(action)
-
-                model_options_tools = model_opts.get(ModelOption.TOOLS, None)
-                if model_options_tools is not None:
-                    assert isinstance(model_options_tools, dict)
-                    for fn_name in model_options_tools:
-                        # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
-                        assert fn_name not in tools.keys(), (
-                            f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
-                        )
-                        # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
-                        assert type(fn_name) is str, (
-                            "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
-                        )
-                        assert callable(model_options_tools[fn_name]), (
-                            "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
-                        )
-                        # Add the model_options tool to the existing set of tools.
-                        tools[fn_name] = model_options_tools[fn_name]
+                add_tools_from_context_actions(tools, [action])
+                add_tools_from_model_options(tools, model_opts)
         return tools

     def _generate_from_raw(
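
Note: the inline tool-gathering removed above now lives in the shared helpers `add_tools_from_context_actions` and `add_tools_from_model_options` imported at the top of this file. As a rough guide, the sketch below is reconstructed from the deleted inline logic; the real implementation is in `mellea.backends.tools` and may differ.

# Sketch only, reconstructed from the inline code this commit removes; see
# mellea.backends.tools for the actual implementation.
from typing import Any, Callable

from mellea.backends.types import ModelOption

def add_tools_from_model_options(tools: dict[str, Callable], model_opts: dict[str, Any]) -> None:
    """Merge ModelOption.TOOLS entries into `tools`, mutating it in place."""
    model_options_tools = model_opts.get(ModelOption.TOOLS, None)
    if model_options_tools is None:
        return
    assert isinstance(model_options_tools, dict), (
        "ModelOption.TOOLS must be a dict[str, Callable] mapping function names to functions."
    )
    for fn_name, fn in model_options_tools.items():
        assert isinstance(fn_name, str) and callable(fn), (
            "ModelOption.TOOLS must be a dict[str, Callable] mapping function names to functions."
        )
        # Tools already supplied by the action's TemplateRepresentation must not be overwritten.
        assert fn_name not in tools, (
            f"Cannot add tool {fn_name}: it was already defined by the action."
        )
        tools[fn_name] = fn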
@@ -333,68 +316,6 @@ def _generate_from_raw(
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
         raise NotImplementedError("This method is not implemented yet.")
-        # extra_body = {}
-        # if format is not None:
-        #     FancyLogger.get_logger().warning(
-        #         "The official OpenAI completion api does not accept response format / structured decoding; "
-        #         "it will be passed as an extra arg."
-        #     )
-        #
-        #     # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
-        #     extra_body["guided_json"] = format.model_json_schema()
-        #
-        # model_opts = self._simplify_and_merge(model_options, is_chat_context=False)
-        #
-        # prompts = [self.formatter.print(action) for action in actions]
-        #
-        # try:
-        #     completion_response: Completion = self._client.completions.create(
-        #         model=self._hf_model_id,
-        #         prompt=prompts,
-        #         extra_body=extra_body,
-        #         **self._make_backend_specific_and_remove(
-        #             model_opts, is_chat_context=False
-        #         ),
-        #     )  # type: ignore
-        # except openai.BadRequestError as e:
-        #     if openai_ollama_batching_error in e.message:
-        #         FancyLogger.get_logger().error(
-        #             "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, "
-        #             "your requests will fail since ollama doesn't support batching requests."
-        #         )
-        #     raise e
-        #
-        # # Necessary for type checker.
-        # assert isinstance(completion_response, Completion)
-        #
-        # results = [
-        #     ModelOutputThunk(
-        #         value=response.text,
-        #         meta={"oai_completion_response": response.model_dump()},
-        #     )
-        #     for response in completion_response.choices
-        # ]
-        #
-        # for i, result in enumerate(results):
-        #     self.formatter.parse(actions[i], result)
-        #
-        # if generate_logs is not None:
-        #     assert isinstance(generate_logs, list)
-        #     date = datetime.datetime.now()
-        #
-        #     for i in range(len(prompts)):
-        #         generate_log = GenerateLog()
-        #         generate_log.prompt = prompts[i]
-        #         generate_log.backend = f"openai::{self.model_id!s}"
-        #         generate_log.model_options = model_opts
-        #         generate_log.date = date
-        #         generate_log.model_output = completion_response
-        #         generate_log.extra = {"seed": model_opts.get("seed", None)}
-        #         generate_log.action = actions[i]
-        #         generate_log.result = results[i]
-        #         generate_logs.append(generate_log)
-        #
-        # return results

     def _extract_model_tool_requests(
         self, tools: dict[str, Callable], chat_response: litellm.ModelResponse

test/backends/test_litellm_ollama.py

Lines changed: 24 additions & 8 deletions
@@ -1,5 +1,5 @@
 import mellea
-from mellea import MelleaSession
+from mellea import MelleaSession, generative
 from mellea.backends import ModelOption
 from mellea.backends.litellm import LiteLLMBackend
 from mellea.stdlib.chat import Message
@@ -18,7 +18,7 @@ def test_litellm_ollama_instruct(self):
         res = self.m.instruct(
             "Write an email to the interns.",
             requirements=["be funny"],
-            strategy=RejectionSamplingStrategy(loop_budget=3)
+            strategy=RejectionSamplingStrategy(loop_budget=3),
         )
         assert res is not None
         assert isinstance(res.value, str)
@@ -29,14 +29,30 @@ def test_litellm_ollama_instruct_options(self):
             requirements=["be funny"],
             model_options={
                 ModelOption.SEED: 123,
-                ModelOption.TEMPERATURE: .5,
-                ModelOption.THINKING:True,
-                ModelOption.MAX_NEW_TOKENS:100,
-                "stream":False,
-                "homer_simpson":"option should be kicked out"
-            }
+                ModelOption.TEMPERATURE: 0.5,
+                ModelOption.THINKING: True,
+                ModelOption.MAX_NEW_TOKENS: 100,
+                "reasoning_effort": True,
+                "stream": False,
+                "homer_simpson": "option should be kicked out",
+            },
         )
         assert res is not None
         assert isinstance(res.value, str)
+        assert "homer_simpson" not in self.m.ctx.last_output_and_logs()[1].model_options

+    def test_gen_slot(self):
+        @generative
+        def is_happy(text: str) -> bool:
+            """Determine if text is of happy mood."""

+        h = is_happy(self.m, text="I'm enjoying life.")
+
+        assert isinstance(h, bool)
+        assert h is True
+
+
+if __name__ == "__main__":
+    import pytest
+
+    pytest.main([__file__])
