
Commit 793844c

fix: watsonx and litellm parameter filtering (#187)
* fix: watsonx param filter
* fix: litellm model options filtering and tests
* fix: change conftest to skip instead of fail qual tests on github
* fix: remove comment
* fix: test defaults
* test: fixes to litellm test
* test: fix test defaults
1 parent 9948907 commit 793844c

File tree

5 files changed: +126 −34 lines changed

* mellea/backends/litellm.py
* mellea/backends/watsonx.py
* test/backends/test_litellm_ollama.py
* test/backends/test_watsonx.py
* test/conftest.py

mellea/backends/litellm.py

Lines changed: 38 additions & 19 deletions

@@ -48,7 +48,8 @@ class LiteLLMBackend(FormatterBackend):
 
     def __init__(
         self,
-        model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
+        model_id: str = "ollama_chat/"
+        + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
         formatter: Formatter | None = None,
         base_url: str | None = "http://localhost:11434",
         model_options: dict | None = None,
@@ -100,7 +101,7 @@ def __init__(
         # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
         # Usually, values that are intentionally extracted while prepping for the backend generate call
         # will be omitted here so that they will be removed when model_options are processed
-        # for the call to the model.
+        # for the call to the model. For LiteLLM, this dict might change slightly depending on the provider.
         self.from_mellea_model_opts_map = {
             ModelOption.SEED: "seed",
             ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
@@ -176,39 +177,57 @@ def _make_backend_specific_and_remove(
         Returns:
             a new dict
         """
-        backend_specific = ModelOption.replace_keys(
-            model_options, self.from_mellea_model_opts_map
-        )
-        backend_specific = ModelOption.remove_special_keys(backend_specific)
-
         # We set `drop_params=True` which will drop non-supported openai params; check for non-openai
         # params that might cause errors and log which openai params aren't supported here.
         # See https://docs.litellm.ai/docs/completion/input.
-        # standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
         supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
             self._model_id
         )
         supported_params = (
             set(supported_params_list) if supported_params_list is not None else set()
         )
 
-        # unknown_keys = []  # keys that are unknown to litellm
-        unsupported_openai_params = []  # openai params that are known to litellm but not supported for this model/provider
+        # LiteLLM-specific remappings (typically based on provider). There are a few cases where the provider
+        # accepts different parameters than LiteLLM says it does. Here are a few rules that help in those scenarios.
+        model_opts_remapping = self.from_mellea_model_opts_map.copy()
+        if (
+            "max_completion_tokens" not in supported_params
+            and "max_tokens" in supported_params
+        ):
+            # Scenario hit by Watsonx. LiteLLM believes Watsonx doesn't accept "max_completion_tokens" even though
+            # OpenAI-compatible endpoints should accept both (and Watsonx does accept both).
+            model_opts_remapping[ModelOption.MAX_NEW_TOKENS] = "max_tokens"
+
+        backend_specific = ModelOption.replace_keys(model_options, model_opts_remapping)
+        backend_specific = ModelOption.remove_special_keys(backend_specific)
+
+        # Since LiteLLM has many different providers, we add some additional parameter logging here.
+        # There are two sets of parameters we have to look at:
+        # - unsupported_openai_params: standard OpenAI parameters that LiteLLM will automatically drop for us when `drop_params=True` if the provider doesn't support them.
+        # - unknown_keys: parameters that LiteLLM doesn't know about, aren't standard OpenAI parameters, and might be used by the provider. We don't drop these.
+        # We want to flag both for the end user.
+        standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
+        unknown_keys = []  # Keys that are unknown to litellm.
+        unsupported_openai_params = []  # OpenAI params that are known to litellm but not supported for this model/provider.
         for key in backend_specific.keys():
             if key not in supported_params:
-                unsupported_openai_params.append(key)
-
-        # if len(unknown_keys) > 0:
-        #     FancyLogger.get_logger().warning(
-        #         f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
-        #     )
+                if key in standard_openai_subset:
+                    # LiteLLM is pretty confident that this standard OpenAI parameter won't work.
+                    unsupported_openai_params.append(key)
+                else:
+                    # LiteLLM doesn't make any claims about this parameter; we won't drop it, but we will keep track of it.
+                    unknown_keys.append(key)
+
+        if len(unknown_keys) > 0:
+            FancyLogger.get_logger().warning(
+                f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
+            )
 
         if len(unsupported_openai_params) > 0:
             FancyLogger.get_logger().warning(
-                f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
+                f"litellm may drop the following openai keys that it doesn't seem to recognize as being supported by the current model/provider: {', '.join(unsupported_openai_params)}"
                 "\nThere are sometimes false positives here."
             )
-            for key in unsupported_openai_params:
-                del backend_specific[key]
 
         return backend_specific
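To make the new behavior concrete, here is a minimal sketch of the provider-aware remapping, mirroring the watsonx scenario exercised in the test file further below. The import path for LiteLLMBackend is inferred from the file path above, and no model request is made; only LiteLLM's local supported-parameter metadata is consulted.

# Hedged sketch: LiteLLM reports "max_tokens" but not "max_completion_tokens" as supported
# for the watsonx provider, so Mellea's MAX_NEW_TOKENS option should now map to "max_tokens".
from mellea.backends.litellm import LiteLLMBackend  # import path inferred from the diff above

backend = LiteLLMBackend(model_id="watsonx/")  # provider prefix only; nothing is called remotely
mellea_opts = backend._simplify_and_merge({"max_tokens": 512})
backend_specific = backend._make_backend_specific_and_remove(mellea_opts)
assert "max_tokens" in backend_specific  # before this fix it would have been "max_completion_tokens"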

mellea/backends/watsonx.py

Lines changed: 8 additions & 3 deletions

@@ -7,6 +7,7 @@
 import os
 import warnings
 from collections.abc import AsyncGenerator, Callable, Coroutine
+from dataclasses import fields
 from typing import Any
 
 from ibm_watsonx_ai import APIClient, Credentials
@@ -110,7 +111,8 @@ def __init__(
         # These are usually values that must be extracted before hand or that are common among backend providers.
         self.to_mellea_model_opts_map_chats = {
             "system": ModelOption.SYSTEM_PROMPT,
-            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,  # Being deprecated in favor of `max_completion_tokens`.
+            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
             "tools": ModelOption.TOOLS,
             "stream": ModelOption.STREAM,
         }
@@ -120,7 +122,7 @@ def __init__(
         # will be omitted here so that they will be removed when model_options are processed
         # for the call to the model.
         self.from_mellea_model_opts_map_chats = {
-            ModelOption.MAX_NEW_TOKENS: "max_tokens"
+            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens"
         }
 
         # See notes above.
@@ -168,7 +170,10 @@ def _get_watsonx_model_id(self) -> str:
 
     def filter_chat_completions_kwargs(self, model_options: dict) -> dict:
         """Filter kwargs to only include valid watsonx chat.completions.create parameters."""
-        chat_params = TextChatParameters.get_sample_params().keys()
+        # TextChatParameters.get_sample_params().keys() can't be completely trusted: it doesn't always contain
+        # all of the accepted keys. In version 1.3.39, max_tokens was removed even though it's still accepted.
+        # It's a dataclass, so use the fields function to get the names.
+        chat_params = {field.name for field in fields(TextChatParameters)}
         return {k: v for k, v in model_options.items() if k in chat_params}
 
     def _simplify_and_merge(
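The switch from get_sample_params() to dataclasses.fields() is the core of the watsonx fix. A self-contained illustration of the technique, using a stand-in dataclass rather than the real TextChatParameters from ibm_watsonx_ai:

# Hedged sketch with a fake parameters class: fields() enumerates every declared field,
# so the filter cannot silently lose keys the way a hand-maintained sample dict can.
from dataclasses import dataclass, fields

@dataclass
class FakeChatParameters:  # stand-in; not the real TextChatParameters
    temperature: float | None = None
    max_tokens: int | None = None
    max_completion_tokens: int | None = None

accepted = {f.name for f in fields(FakeChatParameters)}
model_options = {"temperature": 0.2, "max_tokens": 100, "homer_simpson": 1}
filtered = {k: v for k, v in model_options.items() if k in accepted}
assert filtered == {"temperature": 0.2, "max_tokens": 100}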

test/backends/test_litellm_ollama.py

Lines changed: 62 additions & 11 deletions

@@ -1,4 +1,5 @@
 import asyncio
+import os
 import pytest
 
 from mellea import MelleaSession, generative
@@ -8,14 +9,73 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.sampling import RejectionSamplingStrategy
 
+@pytest.fixture(scope="function")
+def backend(gh_run: int):
+    """Shared LiteLLM backend configured for Ollama."""
+    if gh_run == 1:
+        # LiteLLM prepends 127.0.0.1 with a `/` which causes issues.
+        url = os.environ.get("OLLAMA_HOST", None)
+        if url is None:
+            url = "http://localhost:11434"
+        else:
+            url = url.replace("127.0.0.1", "http://localhost")
+
+        return LiteLLMBackend(
+            model_id="ollama_chat/llama3.2:1b",
+            base_url=url,
+            model_options={"api_base": url},
+        )
+    else:
+        return LiteLLMBackend()
 
 @pytest.fixture(scope="function")
-def session():
+def session(backend):
     """Fresh Ollama session for each test."""
-    session = MelleaSession(LiteLLMBackend())
+    session = MelleaSession(backend=backend)
     yield session
     session.reset()
 
+# Use capsys to check that the logging is working.
+def test_make_backend_specific_and_remove():
+    # Doesn't need to be a real model here; just a provider that LiteLLM knows about.
+    backend = LiteLLMBackend(model_id="ollama_chat/")
+
+    params = {
+        "max_tokens": 1,
+        "stream": 1,
+        ModelOption.TEMPERATURE: 1,
+        "unknown_parameter": 1,  # Unknown / non-OpenAI parameter.
+        "web_search_options": 1,  # Standard OpenAI parameter not supported by Ollama.
+    }
+
+    mellea = backend._simplify_and_merge(params)
+    backend_specific = backend._make_backend_specific_and_remove(mellea)
+
+    # All of these options should be in the model options that get passed to LiteLLM since it handles the dropping.
+    assert "max_completion_tokens" in backend_specific, "max_tokens should get remapped to max_completion_tokens for ollama_chat/"
+    assert "stream" in backend_specific
+    assert "temperature" in backend_specific
+    assert "unknown_parameter" in backend_specific
+    assert "web_search_options" in backend_specific
+
+    # TODO: Investigate why this isn't working on github action runners.
+    # Add the capsys or capfd fixture back.
+    # out = capsys.readouterr()
+    # # Check for the specific warning logs.
+    # assert "supported by the current model/provider: web_search_options" in out.out
+    # assert "mellea won't validate the following params that may cause issues: unknown_parameter" in out.out
+
+    # Do a quick test for the Watsonx-specific scenario.
+    backend = LiteLLMBackend(model_id="watsonx/")
+    watsonx_params = {"max_tokens": 1}
+
+    # Make sure we make it Mellea-specific correctly.
+    watsonx_mellea = backend._simplify_and_merge(watsonx_params)
+    assert ModelOption.MAX_NEW_TOKENS in watsonx_mellea
+
+    watsonx_backend_specific = backend._make_backend_specific_and_remove(watsonx_mellea)
+    assert "max_tokens" in watsonx_backend_specific
 
 @pytest.mark.qualitative
 def test_litellm_ollama_chat(session):
@@ -26,7 +86,6 @@ def test_litellm_ollama_chat(session):
         f"Expected a message with content containing 2 but found {res}"
     )
 
-@pytest.mark.qualitative
 def test_litellm_ollama_instruct(session):
     res = session.instruct(
         "Write an email to the interns.",
@@ -37,7 +96,6 @@ def test_litellm_ollama_instruct(session):
     assert isinstance(res.value, str)
 
 
-@pytest.mark.qualitative
 def test_litellm_ollama_instruct_options(session):
     model_options={
         ModelOption.SEED: 123,
@@ -59,11 +117,6 @@ def test_litellm_ollama_instruct_options(session):
     # make sure that homer_simpson is in the logged model_options
     assert "homer_simpson" in res._generate_log.model_options
 
-    # make sure the backend function filters out the model option when passing to the generate call
-    backend = session.backend
-    assert isinstance(backend, LiteLLMBackend)
-    assert "homer_simpson" not in backend._make_backend_specific_and_remove(model_options)
-
 
 @pytest.mark.qualitative
 def test_gen_slot(session):
@@ -77,7 +130,6 @@ def is_happy(text: str) -> bool:
     # should yield to true - but, of course, is model dependent
     assert h is True
 
-@pytest.mark.qualitative
 async def test_async_parallel_requests(session):
     model_opts = {ModelOption.STREAM: True}
     mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts)
@@ -104,7 +156,6 @@ async def test_async_parallel_requests(session):
     assert m1_final_val == mot1.value
     assert m2_final_val == mot2.value
 
-@pytest.mark.qualitative
 async def test_async_avalue(session):
     mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext())
     m1_final_val = await mot1.avalue()
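The capsys check in test_make_backend_specific_and_remove is currently commented out because it misbehaves on the GitHub runners. One possible alternative, sketched here under the assumption that FancyLogger ultimately delegates to Python's standard logging module (this diff does not confirm that), is pytest's caplog fixture, which captures log records independently of which stream the handler writes to:

# Hedged sketch of a caplog-based variant of the warning check; assumes FancyLogger emits
# standard logging records and that LiteLLMBackend is importable from this path.
import logging

from mellea.backends.litellm import LiteLLMBackend


def test_make_backend_specific_and_remove_warnings(caplog):
    backend = LiteLLMBackend(model_id="ollama_chat/")
    mellea = backend._simplify_and_merge({"unknown_parameter": 1, "web_search_options": 1})
    with caplog.at_level(logging.WARNING):
        backend._make_backend_specific_and_remove(mellea)
    messages = " ".join(record.getMessage() for record in caplog.records)
    assert "unknown_parameter" in messages    # flagged as unknown / non-OpenAI
    assert "web_search_options" in messages   # flagged as unsupported by the provider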

test/backends/test_watsonx.py

Lines changed: 17 additions & 0 deletions

@@ -34,6 +34,23 @@ def session(backend: WatsonxAIBackend):
     yield session
     session.reset()
 
+@pytest.mark.qualitative
+def test_filter_chat_completions_kwargs(backend: WatsonxAIBackend):
+    """Detect changes to the WatsonxAI TextChatParameters."""
+
+    known_keys = ['frequency_penalty', 'logprobs', 'top_logprobs', 'presence_penalty', 'response_format', 'temperature', 'max_tokens', 'max_completion_tokens', 'time_limit', 'top_p', 'n', 'logit_bias', 'seed', 'stop', 'guided_choice', 'guided_regex', 'guided_grammar', 'guided_json']
+    test_dict = {key: 1 for key in known_keys}
+
+    # Make sure keys that we think should be in the TextChatParameters are there.
+    filtered_dict = backend.filter_chat_completions_kwargs(test_dict)
+
+    for key in known_keys:
+        assert key in filtered_dict
+
+    # Make sure unsupported keys still get filtered out.
+    incorrect_dict = {"random": 1}
+    filtered_incorrect_dict = backend.filter_chat_completions_kwargs(incorrect_dict)
+    assert "random" not in filtered_incorrect_dict
 
 @pytest.mark.qualitative
 def test_instruct(session: MelleaSession):

test/conftest.py

Lines changed: 1 addition & 1 deletion

@@ -21,6 +21,6 @@ def pytest_runtest_setup(item):
     gh_run = int(os.environ.get("CICD", 0))
 
     if gh_run == 1:
-        pytest.xfail(
+        pytest.skip(
             reason="Skipping qualitative test: got env variable CICD == 1. Used only in gh workflows."
         )
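For context on the xfail-to-skip change: pytest.xfail() ends the test and reports it as an expected failure (XFAIL), which reads as if the test ran and failed as predicted, while pytest.skip() reports it as simply not run. Below is a self-contained sketch of the hook after this change; the qualitative-marker gate is an assumption based on the markers used in the test files above, since this hunk only shows the xfail-to-skip swap.

# Hedged sketch: skip qualitative tests when running under CI instead of marking them XFAIL.
import os

import pytest


def pytest_runtest_setup(item):
    gh_run = int(os.environ.get("CICD", 0))
    # Assumed gate: only tests carrying the `qualitative` marker are affected.
    if gh_run == 1 and item.get_closest_marker("qualitative") is not None:
        pytest.skip(reason="Skipping qualitative test: got env variable CICD == 1.")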
