
Commit 789baa0

avinash2692, nrfulton, and jakelorocco authored and committed
fix: OpenAI base_url default and reasoning effort model option. (generative-computing#271)
* Adds a fix to pass `reasoning_effort` in conditionally.
* Adds tests.
* Fixes generative-computing#274.
* Adds a GPT 5.1 model identifier.
* Changes the OpenAI backend's default model_id to GPT 5.1. This default is changed because the default base_url is also changed.
* Fixes a bug: GenSlots did not work with the OpenAI platform. The OpenAI response_format only accepts a limited set of schemas and errors out with a 400 if you do not follow its guidelines. One of these guidelines is that additionalProperties is set, and is set to False. This commit monkey-patches the response_format provided to OpenAI platform backends and leaves other OpenAI-"compatible" backends with the existing default behavior. This is a debatable choice. See https://community.openai.com/t/api-rejects-valid-json-schema/906163 and https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
* Adds inline documentation for OpenAI model options monkey-patching.
* Removes a debug print statement.
* Adds a comment about reasoning_effort in the openai SDK.
* Removes all instances of hf_model_id in the OpenAI backend.
* Removes apply_chat_template and adds assertions for environment variables.
* Adds some tests for parameter checking.
* Changes the environment-variable handling logic.
* The base_url check is now a warning.
* fix: change warning message in openai.py
* Marks a test as qualitative because it was causing timeouts in GitHub Actions (I think).

Co-authored-by: Nathan Fulton <gitcommit@nfulton.org>
Co-authored-by: jakelorocco <59755218+jakelorocco@users.noreply.github.com>
1 parent ff4e5bc commit 789baa0

File tree

4 files changed: +159 −49 lines changed


mellea/backends/model_ids.py

Lines changed: 10 additions & 3 deletions
@@ -17,6 +17,7 @@ class ModelIdentifier:
     ollama_name: str | None = None
     watsonx_name: str | None = None
     mlx_name: str | None = None
+    openai_name: str | None = None

     hf_tokenizer_name: str | None = None  # if None, is the same as hf_model_name

@@ -134,9 +135,9 @@ class ModelIdentifier:

 QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")

-######################
-#### OpenAI models ###
-######################
+###########################
+#### OpenAI open models ###
+###########################

 OPENAI_GPT_OSS_20B = ModelIdentifier(
     hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
@@ -145,6 +146,12 @@ class ModelIdentifier:
     hf_model_name="openai/gpt-oss-120b", ollama_name="gpt-oss:120b"
 )

+###########################
+#### OpenAI prop models ###
+###########################
+
+OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")
+
 #####################
 #### Misc models ####
 #####################
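
The new `openai_name` field gives `ModelIdentifier` an OpenAI-platform alias alongside the existing Hugging Face, Ollama, watsonx, and MLX names. A minimal usage sketch, assuming the unset fields default to None as the other fields in this diff do; the `api_key` value is a placeholder:

from mellea.backends import model_ids
from mellea.backends.openai import OpenAIBackend

# The new identifier carries only an OpenAI platform name; it has no
# Hugging Face or Ollama alias.
assert model_ids.OPENAI_GPT_5_1.openai_name == "gpt-5.1"
assert model_ids.OPENAI_GPT_5_1.hf_model_name is None

# OpenAIBackend resolves a ModelIdentifier via `openai_name` (see openai.py
# below); identifiers without one now fail the backend's assertion.
backend = OpenAIBackend(
    model_id=model_ids.OPENAI_GPT_5_1,
    api_key="sk-placeholder",  # hypothetical key; any non-None value passes the new check
)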

mellea/backends/openai.py

Lines changed: 73 additions & 46 deletions
@@ -6,6 +6,7 @@
 import functools
 import inspect
 import json
+import os
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from enum import Enum
@@ -72,7 +73,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin):

     def __init__(
         self,
-        model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
+        model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1,
         formatter: Formatter | None = None,
         base_url: str | None = None,
         model_options: dict | None = None,
@@ -142,26 +143,38 @@ def __init__(

         self.default_to_constraint_checking_alora = default_to_constraint_checking_alora

-        self._model_id = model_id
         match model_id:
             case str():
-                self._hf_model_id = model_id
+                self._model_id = model_id
             case ModelIdentifier():
-                assert model_id.hf_model_name is not None, (
-                    "model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
+                assert model_id.openai_name is not None, (
+                    "model_id is None. This can also happen if the ModelIdentifier has no `openai_name` set."
                 )
-                self._hf_model_id = model_id.hf_model_name
+                self._model_id = model_id.openai_name

-        if base_url is None:
-            self._base_url = "http://localhost:11434/v1"  # ollama
-        else:
-            self._base_url = base_url
-        if api_key is None:
-            self._api_key = "ollama"
-        else:
-            self._api_key = api_key
+        # Use provided parameters or fall back to environment variables.
+        self._api_key = api_key
+        self._base_url = base_url

-        self._server_type = _server_type(self._base_url)
+        # Validate that we have the required configuration.
+        if self._api_key is None and os.getenv("OPENAI_API_KEY") is None:
+            raise ValueError(
+                "OPENAI_API_KEY or api_key is required but not set. Please either:\n"
+                "  1. Set the environment variable: export OPENAI_API_KEY='your-key-here'\n"
+                "  2. Pass it as a parameter: OpenAIBackend(api_key='your-key-here')"
+            )
+
+        if self._base_url is None and os.getenv("OPENAI_BASE_URL") is None:
+            FancyLogger.get_logger().warning(
+                "OPENAI_BASE_URL or base_url is not set.\n"
+                "The openai SDK is going to assume that the base_url is `https://api.openai.com/v1`."
+            )
+
+        self._server_type: _ServerType = (
+            _server_type(self._base_url)
+            if self._base_url is not None
+            else _ServerType.OPENAI
+        )  # type: ignore

         self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
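
The net effect of this hunk: explicit constructor arguments win, environment variables are the fallback, and the old implicit Ollama defaults (`http://localhost:11434/v1` / `"ollama"`) are gone. A standalone sketch of the resolution order; the helper `resolve_openai_config` is hypothetical, for illustration only:

import os

def resolve_openai_config(api_key: str | None, base_url: str | None) -> tuple[str, str | None]:
    # Explicit parameters take precedence; environment variables are the fallback.
    key = api_key or os.getenv("OPENAI_API_KEY")
    if key is None:
        raise ValueError("OPENAI_API_KEY or api_key is required but not set.")
    # A missing base_url is only a warning: the openai SDK itself falls back to
    # https://api.openai.com/v1, which is why the backend assumes
    # _ServerType.OPENAI in that case.
    url = base_url or os.getenv("OPENAI_BASE_URL")
    return key, url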

@@ -668,14 +681,38 @@ async def _generate_from_chat_context_standard(

         extra_params: dict[str, Any] = {}
         if _format is not None:
-            extra_params["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": _format.__name__,
-                    "schema": _format.model_json_schema(),
-                    "strict": True,
-                },
-            }
+            if self._server_type == _ServerType.OPENAI:
+                # The OpenAI platform requires that additionalProperties=False on all response_format schemas.
+                # However, not all schemas generated by Mellea include additionalProperties.
+                # GenerativeSlot, in particular, does not add this property.
+                # The easiest way to address this disparity between OpenAI and other inference providers is to
+                # monkey-patch the response format exactly when we are actually using the OpenAI server.
+                #
+                # This only addresses the additionalProperties=False constraint.
+                # Other constraints we should be checking/patching are described here:
+                # https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
+                monkey_patched_response_schema = _format.model_json_schema()
+                monkey_patched_response_schema["additionalProperties"] = False
+                extra_params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": _format.__name__,
+                        "schema": monkey_patched_response_schema,
+                        "strict": True,
+                    },
+                }
+            else:
+                FancyLogger().get_logger().warning(
+                    "Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on the JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you have found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider."
+                )
+                extra_params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": _format.__name__,
+                        "schema": _format.model_json_schema(),
+                        "strict": True,
+                    },
+                }

         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
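
For context, here is the failure mode the OPENAI branch above works around, as a standalone sketch; the `Sentiment` model is hypothetical, standing in for any schema produced without additionalProperties (as with GenerativeSlot):

from pydantic import BaseModel

class Sentiment(BaseModel):
    label: str
    confidence: float

schema = Sentiment.model_json_schema()
# Pydantic does not emit additionalProperties here by default, but the OpenAI
# platform rejects strict response_format schemas that lack
# additionalProperties=False (it returns HTTP 400).
schema["additionalProperties"] = False

response_format = {
    "type": "json_schema",
    "json_schema": {"name": "Sentiment", "schema": schema, "strict": True},
}
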
@@ -701,15 +738,21 @@ async def _generate_from_chat_context_standard(
         formatted_tools = convert_tools_to_json(tools)
         use_tools = len(formatted_tools) > 0

+        # Build optional reasoning parameters.
+        # NOTE: the openai SDK doesn't like it if you pass the `reasoning_effort` param to a non-reasoning model, e.g. gpt-4o.
+        reasoning_params = {}
+        if thinking is not None:
+            reasoning_params["reasoning_effort"] = thinking
+
         chat_response: Coroutine[
             Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk]
         ] = self._async_client.chat.completions.create(
-            model=self._hf_model_id,
+            model=self._model_id,
             messages=conversation,  # type: ignore
-            reasoning_effort=thinking,  # type: ignore
             tools=formatted_tools if use_tools else None,  # type: ignore
             # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
             **extra_params,
+            **reasoning_params,  # type: ignore
             **self._make_backend_specific_and_remove(
                 model_opts, is_chat_context=ctx.is_chat_context
             ),
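
Why a dict splat instead of `reasoning_effort=thinking`: splatting an empty dict adds no keyword argument at all, so non-reasoning models never see the parameter, whereas passing `reasoning_effort=None` explicitly can still trip the SDK. A minimal illustration of the pattern; `create` here is a hypothetical stand-in for `chat.completions.create`:

def create(model: str, **kwargs):
    # Stand-in for the SDK call: with an empty splat, the reasoning_effort
    # key is simply absent rather than present with an explicit None.
    return kwargs

assert "reasoning_effort" not in create("gpt-4o")
assert create("gpt-5.1", **{"reasoning_effort": "medium"})["reasoning_effort"] == "medium"
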
@@ -877,7 +920,7 @@ async def generate_from_raw(
         try:
             completion_response: Completion = (
                 await self._async_client.completions.create(
-                    model=self._hf_model_id,
+                    model=self._model_id,
                     prompt=prompts,
                     extra_body=extra_body,
                     **self._make_backend_specific_and_remove(
@@ -930,7 +973,10 @@ async def generate_from_raw(
     @property
     def base_model_name(self):
         """Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`."""
-        return self._hf_model_id.split("/")[1]
+        if "/" in self._model_id:
+            return self._model_id.split("/")[1]
+        else:
+            return self._model_id

     def add_adapter(self, adapter: OpenAIAdapter):
         """Adds the given adapter to the backend. Must not have been added to a different backend."""
@@ -1085,22 +1131,3 @@ def list_adapters(self) -> list[str]:
         :returns: list of adapter names that are currently registered with this backend
         """
         return list(self._loaded_adapters.keys())
-
-    def apply_chat_template(self, chat: list[dict[str, str]]):
-        """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
-        from transformers import AutoTokenizer
-
-        if not hasattr(self, "_tokenizer"):
-            match _server_type(self._base_url):
-                case _ServerType.LOCALHOST:
-                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
-                        AutoTokenizer.from_pretrained(self._hf_model_id)
-                    )
-                case _ServerType.OPENAI:
-                    raise Exception(
-                        "apply_chat_template is called while targeting a server at openai.com. "
-                        "This is not supported --- openai.com does not support Activated Lora. "
-                        "Use a locally served vllm instance. "
-                    )
-
-        return self._tokenizer.apply_chat_template(chat, tokenize=False)

test/backends/test_openai_ollama.py

Lines changed: 75 additions & 0 deletions
@@ -1,6 +1,7 @@
 # test/rits_backend_tests/test_openai_integration.py
 import asyncio
 import os
+from unittest.mock import patch

 import openai
 import pydantic
@@ -216,6 +217,80 @@ async def get_client_async():
     assert len(backend._client_cache.cache.values()) == 2


+async def test_reasoning_effort_conditional_passing(backend):
+    """Test that reasoning_effort is only passed to the API when not None."""
+    from unittest.mock import AsyncMock, MagicMock, patch
+
+    ctx = ChatContext()
+    ctx = ctx.add(CBlock(value="Test"))
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = "Response"
+    mock_response.choices[0].message.role = "assistant"
+
+    # Test 1: reasoning_effort should NOT be passed when not specified.
+    with patch.object(
+        backend._async_client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_response
+        await backend.generate_from_chat_context(
+            CBlock(value="Hi"), ctx, model_options={}
+        )
+        call_kwargs = mock_create.call_args.kwargs
+        assert "reasoning_effort" not in call_kwargs, (
+            "reasoning_effort should not be passed when not specified"
+        )
+
+    # Test 2: reasoning_effort SHOULD be passed when specified.
+    with patch.object(
+        backend._async_client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_response
+        await backend.generate_from_chat_context(
+            CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
+        )
+        call_kwargs = mock_create.call_args.kwargs
+        assert call_kwargs.get("reasoning_effort") == "medium", (
+            "reasoning_effort should be passed with the correct value when specified"
+        )
+
+
+def test_api_key_and_base_url_from_parameters():
+    """Test that the API key and base URL can be set via parameters."""
+    backend = OpenAIBackend(
+        model_id="gpt-4", api_key="test-api-key", base_url="https://api.test.com/v1"
+    )
+    assert backend._api_key == "test-api-key"
+    assert backend._base_url == "https://api.test.com/v1"
+
+
+def test_parameter_overrides_env_variable():
+    """Test that explicit parameters override environment variables."""
+    with patch.dict(
+        os.environ,
+        {"OPENAI_API_KEY": "env-api-key", "OPENAI_BASE_URL": "https://api.env.com/v1"},
+    ):
+        backend = OpenAIBackend(
+            model_id="gpt-4",
+            api_key="param-api-key",
+            base_url="https://api.param.com/v1",
+        )
+        assert backend._api_key == "param-api-key"
+        assert backend._base_url == "https://api.param.com/v1"
+
+
+def test_missing_api_key_raises_error():
+    """Test that a missing API key raises a ValueError with a helpful message."""
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(ValueError) as exc_info:
+            OpenAIBackend(model_id="gpt-4", base_url="https://api.test.com/v1")
+        assert "OPENAI_API_KEY or api_key is required but not set" in str(
+            exc_info.value
+        )
+
+
 if __name__ == "__main__":
     import pytest

test/stdlib_intrinsics/test_rag/test_rag.py

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ def test_answer_relevance(backend):
     assert result == answer


+@pytest.mark.qualitative
 def test_answer_relevance_classifier(backend):
     """Verify that the first phase of the answer relevance flow behaves as expected."""
     context, answer, docs = _read_input_json("answer_relevance.json")
