
Commit 5f14ff7

Merge pull request #62 from generative-computing/jal/litellm_310_changes
model options changes for litellm
2 parents 9310eaa + ee02bc7

File tree

1 file changed: +116 -25 lines changed

mellea/backends/litellm.py

Lines changed: 116 additions & 25 deletions
@@ -3,8 +3,11 @@
 import datetime
 import json
 from collections.abc import Callable
+from typing import Any
 
 import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.get_supported_openai_params
 
 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
@@ -61,6 +64,31 @@ def __init__(
         else:
             self._base_url = base_url
 
+        # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
+        # These are usually values that must be extracted before hand or that are common among backend providers.
+        # OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
+        # users should only be specifying a single one in their request.
+        self.to_mellea_model_opts_map = {
+            "system": ModelOption.SYSTEM_PROMPT,
+            "reasoning_effort": ModelOption.THINKING,  # TODO: JAL; see which of these are actually extracted...
+            "seed": ModelOption.SEED,
+            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "tools": ModelOption.TOOLS,
+            "functions": ModelOption.TOOLS,
+        }
+
+        # A mapping of Mellea specific ModelOptions to the specific names for this backend.
+        # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
+        # Usually, values that are intentionally extracted while prepping for the backend generate call
+        # will be omitted here so that they will be removed when model_options are processed
+        # for the call to the model.
+        self.from_mellea_model_opts_map = {
+            ModelOption.SEED: "seed",
+            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
+            ModelOption.THINKING: "reasoning_effort",
+        }
+
     def generate_from_context(
         self,
         action: Component | CBlock,
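
The two dicts added above drive a key-translation step: provider-style option names are first normalized to Mellea's ModelOption constants, and later mapped back to the names litellm expects. A minimal, self-contained sketch of that normalization (the replace_keys helper and the placeholder constants here are illustrative stand-ins, not Mellea's actual implementation; the precedence rule follows the _simplify_and_merge docstring further down in this diff):

    # Illustrative stand-ins for ModelOption constants (hypothetical values).
    MAX_NEW_TOKENS = "<mellea:max_new_tokens>"
    SYSTEM_PROMPT = "<mellea:system_prompt>"

    def replace_keys(opts: dict, mapping: dict) -> dict:
        """Remap provider-style keys to Mellea keys; if the Mellea key already
        exists, keep its value (existing keys take precedence)."""
        out = dict(opts)
        for provider_key, mellea_key in mapping.items():
            if provider_key in out:
                value = out.pop(provider_key)
                out.setdefault(mellea_key, value)
        return out

    to_mellea = {
        "max_tokens": MAX_NEW_TOKENS,
        "max_completion_tokens": MAX_NEW_TOKENS,
        "system": SYSTEM_PROMPT,
    }

    user_opts = {"max_tokens": 256, "system": "Answer briefly.", "temperature": 0.2}
    print(replace_keys(user_opts, to_mellea))
    # temperature passes through untouched; max_tokens and system are re-keyed
    # to the placeholder Mellea constants.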
@@ -84,30 +112,87 @@ def generate_from_context(
             tool_calls=tool_calls,
         )
 
-    def _simplify_and_merge(self, mo: dict) -> dict:
-        mo_safe = {} if mo is None else mo.copy()
-        mo_merged = ModelOption.merge_model_options(self.model_options, mo_safe)
+    def _simplify_and_merge(
+        self, model_options: dict[str, Any] | None
+    ) -> dict[str, Any]:
+        """Simplifies model_options to use the Mellea specific ModelOption.Option and merges the backend's model_options with those passed into this call.
 
-        # map to valid litellm names
-        mo_mapping = {
-            ModelOption.TOOLS: "tools",
-            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
-            ModelOption.SEED: "seed",
-            ModelOption.THINKING: "thinking",
-        }
-        mo_res = ModelOption.replace_keys(mo_merged, mo_mapping)
-        mo_res = ModelOption.remove_special_keys(mo_res)
-
-        supported_params = litellm.get_supported_openai_params(self._model_id)
-        assert supported_params is not None
-        for k in list(mo_res.keys()):
-            if k not in supported_params:
-                del mo_res[k]
-                FancyLogger.get_logger().warn(
-                    f"Skipping '{k}' -- Model-Option not supported by {self.model_id}."
-                )
+        Rules:
+        - Within a model_options dict, existing keys take precedence. This means remapping to mellea specific keys will maintain the value of the mellea specific key if one already exists.
+        - When merging, the keys/values from the dictionary passed into this function take precedence.
+
+        Because this function simplifies and then merges, non-Mellea keys from the passed in model_options will replace
+        Mellea specific keys from the backend's model_options.
 
-        return mo_res
+        Args:
+            model_options: the model_options for this call
+
+        Returns:
+            a new dict
+        """
+        backend_model_opts = ModelOption.replace_keys(
+            self.model_options, self.to_mellea_model_opts_map
+        )
+
+        if model_options is None:
+            return backend_model_opts
+
+        generate_call_model_opts = ModelOption.replace_keys(
+            model_options, self.to_mellea_model_opts_map
+        )
+        return ModelOption.merge_model_options(
+            backend_model_opts, generate_call_model_opts
+        )
+
+    def _make_backend_specific_and_remove(
+        self, model_options: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Maps specified Mellea specific keys to their backend specific version and removes any remaining Mellea keys.
+
+        Additionally, logs any params unknown to litellm and any params that are openai specific but not supported by this model/provider.
+
+        Args:
+            model_options: the model_options for this call
+
+        Returns:
+            a new dict
+        """
+        backend_specific = ModelOption.replace_keys(
+            model_options, self.from_mellea_model_opts_map
+        )
+        backend_specific = ModelOption.remove_special_keys(backend_specific)
+
+        # We set `drop_params=True` which will drop non-supported openai params; check for non-openai
+        # params that might cause errors and log which openai params aren't supported here.
+        # See https://docs.litellm.ai/docs/completion/input.
+        standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
+        supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
+            self._model_id
+        )
+        supported_params = (
+            set(supported_params_list) if supported_params_list is not None else set()
+        )
+
+        unknown_keys = []  # keys that are unknown to litellm
+        unsupported_openai_params = []  # openai params that are known to litellm but not supported for this model/provider
+        for key in backend_specific.keys():
+            if key not in standard_openai_subset.keys():
+                unknown_keys.append(key)
+
+            elif key not in supported_params:
+                unsupported_openai_params.append(key)
+
+        if len(unknown_keys) > 0:
+            FancyLogger.get_logger().warning(
+                f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
+            )
+
+        if len(unsupported_openai_params) > 0:
+            FancyLogger.get_logger().warning(
+                f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
+            )
+
+        return backend_specific
 
     def _generate_from_chat_context_standard(
         self,
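
The warning logic in _make_backend_specific_and_remove relies on litellm's own view of which OpenAI-style parameters a given model/provider accepts. A standalone probe in the same spirit, assuming only that litellm.get_supported_openai_params(model) returns a list of parameter names or None (the form the replaced code already used); the model id below is a placeholder:

    import litellm

    def preview_dropped_params(model: str, opts: dict) -> list[str]:
        """Return the option names outside litellm's supported OpenAI params for
        this model, i.e. roughly what drop_params=True would discard.
        Illustrative helper, not part of this commit."""
        supported = set(litellm.get_supported_openai_params(model) or [])
        return [key for key in opts if key not in supported]

    opts = {"seed": 7, "max_completion_tokens": 256, "reasoning_effort": "low"}
    print(preview_dropped_params("openai/gpt-4o-mini", opts))  # placeholder model id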
@@ -120,7 +205,6 @@ def _generate_from_chat_context_standard(
         generate_logs: list[GenerateLog] | None = None,
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
-        model_options = {} if model_options is None else model_options
         model_opts = self._simplify_and_merge(model_options)
         linearized_context = ctx.linearize()
         assert linearized_context is not None, (
@@ -136,7 +220,7 @@ def _generate_from_chat_context_standard(
         messages.extend(self.formatter.to_chat_messages([action]))
 
         conversation: list[dict] = []
-        system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, "")
+        system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "")
         if system_prompt != "":
             conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([{"role": m.role, "content": m.content} for m in messages])
@@ -153,6 +237,11 @@ def _generate_from_chat_context_standard(
         else:
             response_format = {"type": "text"}
 
+        thinking = model_opts.get(ModelOption.THINKING, None)
+        if type(thinking) is bool and thinking:
+            # OpenAI uses strings for its reasoning levels.
+            thinking = "medium"
+
         # Append tool call information if applicable.
         tools = self._extract_tools(action, format, model_opts, tool_calls)
         formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None
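
The coercion above lets callers give the THINKING option either as a boolean flag or as one of OpenAI's reasoning-effort strings; a tiny self-contained mirror of that rule:

    def normalize_thinking(thinking):
        """Mirror of the coercion above: a bare True becomes OpenAI's middle
        reasoning level, strings pass through, other values are left alone."""
        if type(thinking) is bool and thinking:
            return "medium"
        return thinking

    assert normalize_thinking(True) == "medium"
    assert normalize_thinking("high") == "high"
    assert normalize_thinking(None) is None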
@@ -162,7 +251,9 @@ def _generate_from_chat_context_standard(
             messages=conversation,
             tools=formatted_tools,
             response_format=response_format,
-            **model_opts,
+            reasoning_effort=thinking,  # type: ignore
+            drop_params=True,  # See note in `_make_backend_specific_and_remove`.
+            **self._make_backend_specific_and_remove(model_opts),
         )
 
         choice_0 = chat_response.choices[0]

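Putting the pieces together, the completion call this diff arrives at has roughly the following shape when written out directly. This is a sketch with a placeholder model id and message; drop_params=True asks litellm to silently drop OpenAI-style parameters the chosen provider does not accept:

    import litellm

    response = litellm.completion(
        model="openai/gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": "Say hello in five words."}],
        response_format={"type": "text"},
        reasoning_effort="medium",   # what a bare THINKING=True is coerced to above
        drop_params=True,            # see the note in _make_backend_specific_and_remove
        max_completion_tokens=64,    # example of a remapped backend-specific option
    )
    print(response.choices[0].message.content)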