
Commit 9733df8

avinash2692, nrfulton, and jakelorocco authored
fix: OpenAI base_url default and reasoning effort model option. (#271)
* Adding a fix to pass `reasoning_effort` in conditionally
* adding tests
* Fixes #274
* Adds GPT 5.1 model identifier.
* Changes OpenAI Backend default model_id to GPT 5.1. This default is changed because the default base_url is also changed.
* Fixes bug: GenSlots did not work with OpenAI platform. The OpenAI response_format only accepts a limited set of schemas and will error out with a 400 if you do not follow their guidelines. One of these guidelines is that additionalProperties is set and is set to False. This commit monkey-patches the response_format provided to OpenAI platform backends, and leaves other OpenAI-"compatible" backends with the existing default behavior. This is a debatable choice. See https://community.openai.com/t/api-rejects-valid-json-schema/906163 and https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
* Adds inline documentation for OpenAI model options monkey patching.
* removes debug print stmt.
* adding a comment about reasoning_effort in openai sdk
* removing all instances of hf_model_id in openai backend
* removing apply_chat_template and adding assertions for env variable
* adding some tests for param checking
* changing env variable handling logic.
* base_url check is now a warning
* fix: change warning message in openai.py
* marking test as qualitative because it's causing timeouts in github actions (I think)

Co-authored-by: Nathan Fulton <[email protected]>
Co-authored-by: jakelorocco <[email protected]>
1 parent 0b402bd commit 9733df8

File tree

4 files changed: +159 −49 lines changed

mellea/backends/model_ids.py
mellea/backends/openai.py
test/backends/test_openai_ollama.py
test/stdlib_intrinsics/test_rag/test_rag.py

mellea/backends/model_ids.py

Lines changed: 10 additions & 3 deletions
@@ -17,6 +17,7 @@ class ModelIdentifier:
     ollama_name: str | None = None
     watsonx_name: str | None = None
     mlx_name: str | None = None
+    openai_name: str | None = None

     hf_tokenizer_name: str | None = None  # if None, is the same as hf_model_name

@@ -134,9 +135,9 @@ class ModelIdentifier:

 QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")

-######################
-#### OpenAI models ###
-######################
+###########################
+#### OpenAI open models ###
+###########################

 OPENAI_GPT_OSS_20B = ModelIdentifier(
     hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
@@ -145,6 +146,12 @@ class ModelIdentifier:
     hf_model_name="openai/gpt-oss-120b", ollama_name="gpt-oss:120b"
 )

+###########################
+#### OpenAI prop models ###
+###########################
+
+OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")
+
 #####################
 #### Misc models ####
 #####################
mellea/backends/openai.py

Lines changed: 73 additions & 46 deletions
@@ -6,6 +6,7 @@
 import functools
 import inspect
 import json
+import os
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from enum import Enum
@@ -72,7 +73,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin):

     def __init__(
         self,
-        model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
+        model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1,
         formatter: Formatter | None = None,
         base_url: str | None = None,
         model_options: dict | None = None,
@@ -142,26 +143,38 @@ def __init__(

         self.default_to_constraint_checking_alora = default_to_constraint_checking_alora

-        self._model_id = model_id
         match model_id:
             case str():
-                self._hf_model_id = model_id
+                self._model_id = model_id
             case ModelIdentifier():
-                assert model_id.hf_model_name is not None, (
-                    "model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
+                assert model_id.openai_name is not None, (
+                    "model_id is None. This can also happen if the ModelIdentifier has no `openai_name` name set."
                 )
-                self._hf_model_id = model_id.hf_model_name
+                self._model_id = model_id.openai_name

-        if base_url is None:
-            self._base_url = "http://localhost:11434/v1"  # ollama
-        else:
-            self._base_url = base_url
-        if api_key is None:
-            self._api_key = "ollama"
-        else:
-            self._api_key = api_key
+        # Use provided parameters or fall back to environment variables
+        self._api_key = api_key
+        self._base_url = base_url

-        self._server_type = _server_type(self._base_url)
+        # Validate that we have the required configuration
+        if self._api_key is None and os.getenv("OPENAI_API_KEY") is None:
+            raise ValueError(
+                "OPENAI_API_KEY or api_key is required but not set. Please either:\n"
+                "  1. Set the environment variable: export OPENAI_API_KEY='your-key-here'\n"
+                "  2. Pass it as a parameter: OpenAIBackend(api_key='your-key-here')"
+            )
+
+        if self._base_url is None and os.getenv("OPENAI_BASE_URL") is None:
+            FancyLogger.get_logger().warning(
+                "OPENAI_BASE_URL or base_url is not set.\n"
+                "The openai SDK is going to assume that the base_url is `https://api.openai.com/v1`"
+            )
+
+        self._server_type: _ServerType = (
+            _server_type(self._base_url)
+            if self._base_url is not None
+            else _ServerType.OPENAI
+        )  # type: ignore

         self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
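
In plain terms, the constructor now has three configuration paths. A hedged sketch of the resulting behavior (the OpenAIBackend import path is an assumption; the behavior mirrors the diff above and the tests added below):

import os

from mellea.backends.openai import OpenAIBackend  # import path assumed

# 1. Explicit parameters are stored on the backend and take precedence over env vars.
b1 = OpenAIBackend(model_id="gpt-4", api_key="param-key", base_url="https://api.param.com/v1")

# 2. With no parameters, the backend defers to OPENAI_API_KEY / OPENAI_BASE_URL.
#    If base_url is unset in both places, a warning is logged and the openai SDK
#    falls back to https://api.openai.com/v1.
os.environ["OPENAI_API_KEY"] = "env-key"
b2 = OpenAIBackend(model_id="gpt-4")

# 3. No api_key parameter and no OPENAI_API_KEY raises a ValueError at construction time.
os.environ.pop("OPENAI_API_KEY", None)
try:
    OpenAIBackend(model_id="gpt-4", base_url="https://api.test.com/v1")
except ValueError as err:
    print(err)  # "OPENAI_API_KEY or api_key is required but not set. ..."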

@@ -598,14 +611,38 @@ async def _generate_from_chat_context_standard(

         extra_params: dict[str, Any] = {}
         if _format is not None:
-            extra_params["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": _format.__name__,
-                    "schema": _format.model_json_schema(),
-                    "strict": True,
-                },
-            }
+            if self._server_type == _ServerType.OPENAI:
+                # The OpenAI platform requires that additionalProperties=False on all response_format schemas.
+                # However, not all schemas generated by Mellea include additionalProperties.
+                # GenerativeSlot, in particular, does not add this property.
+                # The easiest way to address this disparity between OpenAI and other inference providers is to
+                # monkey-patch the response format exactly when we are actually using the OpenAI server.
+                #
+                # This only addresses the additionalProperties=False constraint.
+                # Other constraints we should be checking/patching are described here:
+                # https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
+                monkey_patched_response_schema = _format.model_json_schema()
+                monkey_patched_response_schema["additionalProperties"] = False
+                extra_params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": _format.__name__,
+                        "schema": monkey_patched_response_schema,
+                        "strict": True,
+                    },
+                }
+            else:
+                FancyLogger().get_logger().warning(
+                    "Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on supported JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider."
+                )
+                extra_params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": _format.__name__,
+                        "schema": _format.model_json_schema(),
+                        "strict": True,
+                    },
+                }

         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
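
The patch above only sets `additionalProperties` on the top-level schema dict. A self-contained sketch of the same transformation using plain Pydantic (v2), independent of Mellea:

from pydantic import BaseModel

class Person(BaseModel):
    name: str
    age: int

schema = Person.model_json_schema()
# Pydantic does not emit additionalProperties by default; the OpenAI platform rejects
# such schemas for strict structured outputs with a 400 error.
assert "additionalProperties" not in schema

schema["additionalProperties"] = False  # the same patch applied in the diff above
response_format = {
    "type": "json_schema",
    "json_schema": {"name": Person.__name__, "schema": schema, "strict": True},
}

Note that nested models referenced through `$defs` would still lack the flag, which is why the inline comment defers the remaining constraints to the structured-outputs guidelines.
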
@@ -631,15 +668,21 @@ async def _generate_from_chat_context_standard(
         formatted_tools = convert_tools_to_json(tools)
         use_tools = len(formatted_tools) > 0

+        # Build optional reasoning parameters
+        # NOTE: the openai SDK doesn't like it if you pass the `reasoning_effort` param to a non-reasoning model, e.g. gpt-4o
+        reasoning_params = {}
+        if thinking is not None:
+            reasoning_params["reasoning_effort"] = thinking
+
         chat_response: Coroutine[
             Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk]
         ] = self._async_client.chat.completions.create(
-            model=self._hf_model_id,
+            model=self._model_id,
             messages=conversation,  # type: ignore
-            reasoning_effort=thinking,  # type: ignore
             tools=formatted_tools if use_tools else None,  # type: ignore
             # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
             **extra_params,
+            **reasoning_params,  # type: ignore
             **self._make_backend_specific_and_remove(
                 model_opts, is_chat_context=ctx.is_chat_context
             ),
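
The key detail is that `reasoning_effort` is now omitted entirely rather than sent as None. A minimal sketch of the conditional-kwarg pattern in isolation (the helper name is illustrative only, not part of the codebase):

def build_request_kwargs(model: str, thinking: str | None) -> dict:
    # Only include reasoning_effort when a value was actually requested; per the
    # commit, passing it to a non-reasoning model (e.g. gpt-4o) causes errors.
    reasoning_params: dict = {}
    if thinking is not None:
        reasoning_params["reasoning_effort"] = thinking  # e.g. "low", "medium", "high"
    return {"model": model, **reasoning_params}

assert "reasoning_effort" not in build_request_kwargs("gpt-4o", None)
assert build_request_kwargs("gpt-5.1", "medium")["reasoning_effort"] == "medium"
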
@@ -807,7 +850,7 @@ async def generate_from_raw(
         try:
             completion_response: Completion = (
                 await self._async_client.completions.create(
-                    model=self._hf_model_id,
+                    model=self._model_id,
                     prompt=prompts,
                     extra_body=extra_body,
                     **self._make_backend_specific_and_remove(
@@ -860,7 +903,10 @@ async def generate_from_raw(
     @property
     def base_model_name(self):
         """Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`."""
-        return self._hf_model_id.split("/")[1]
+        if "/" in self._model_id:
+            return self._model_id.split("/")[1]
+        else:
+            return self._model_id

     def add_adapter(self, adapter: OpenAIAdapter):
         """Adds the given adapter to the backend. Must not have been added to a different backend."""
@@ -970,22 +1016,3 @@ def list_adapters(self) -> list[str]:
         :returns: list of adapter names that are currently registered with this backend
         """
         return list(self._loaded_adapters.keys())
-
-    def apply_chat_template(self, chat: list[dict[str, str]]):
-        """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
-        from transformers import AutoTokenizer
-
-        if not hasattr(self, "_tokenizer"):
-            match _server_type(self._base_url):
-                case _ServerType.LOCALHOST:
-                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
-                        AutoTokenizer.from_pretrained(self._hf_model_id)
-                    )
-                case _ServerType.OPENAI:
-                    raise Exception(
-                        "apply_chat_template is called while targeting a server at openai.com. "
-                        "This is not supported --- openai.com does not support Activated Lora. "
-                        "Use a locally served vllm instance. "
-                    )
-
-        return self._tokenizer.apply_chat_template(chat, tokenize=False)

test/backends/test_openai_ollama.py

Lines changed: 75 additions & 0 deletions
@@ -1,6 +1,7 @@
 # test/rits_backend_tests/test_openai_integration.py
 import asyncio
 import os
+from unittest.mock import patch

 import openai
 import pydantic
@@ -216,6 +217,80 @@ async def get_client_async():
     assert len(backend._client_cache.cache.values()) == 2


+async def test_reasoning_effort_conditional_passing(backend):
+    """Test that reasoning_effort is only passed to API when not None."""
+    from unittest.mock import AsyncMock, MagicMock, patch
+
+    ctx = ChatContext()
+    ctx = ctx.add(CBlock(value="Test"))
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = "Response"
+    mock_response.choices[0].message.role = "assistant"
+
+    # Test 1: reasoning_effort should NOT be passed when not specified
+    with patch.object(
+        backend._async_client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_response
+        await backend.generate_from_chat_context(
+            CBlock(value="Hi"), ctx, model_options={}
+        )
+        call_kwargs = mock_create.call_args.kwargs
+        assert "reasoning_effort" not in call_kwargs, (
+            "reasoning_effort should not be passed when not specified"
+        )
+
+    # Test 2: reasoning_effort SHOULD be passed when specified
+    with patch.object(
+        backend._async_client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_response
+        await backend.generate_from_chat_context(
+            CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
+        )
+        call_kwargs = mock_create.call_args.kwargs
+        assert call_kwargs.get("reasoning_effort") == "medium", (
+            "reasoning_effort should be passed with correct value when specified"
+        )
+
+
+def test_api_key_and_base_url_from_parameters():
+    """Test that API key and base URL can be set via parameters."""
+    backend = OpenAIBackend(
+        model_id="gpt-4", api_key="test-api-key", base_url="https://api.test.com/v1"
+    )
+    assert backend._api_key == "test-api-key"
+    assert backend._base_url == "https://api.test.com/v1"
+
+
+def test_parameter_overrides_env_variable():
+    """Test that explicit parameters override environment variables."""
+    with patch.dict(
+        os.environ,
+        {"OPENAI_API_KEY": "env-api-key", "OPENAI_BASE_URL": "https://api.env.com/v1"},
+    ):
+        backend = OpenAIBackend(
+            model_id="gpt-4",
+            api_key="param-api-key",
+            base_url="https://api.param.com/v1",
+        )
+        assert backend._api_key == "param-api-key"
+        assert backend._base_url == "https://api.param.com/v1"
+
+
+def test_missing_api_key_raises_error():
+    """Test that missing API key raises ValueError with helpful message."""
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(ValueError) as exc_info:
+            OpenAIBackend(model_id="gpt-4", base_url="https://api.test.com/v1")
+        assert "OPENAI_API_KEY or api_key is required but not set" in str(
+            exc_info.value
+        )
+
+
 if __name__ == "__main__":
     import pytest

test/stdlib_intrinsics/test_rag/test_rag.py

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ def test_answer_relevance(backend):
     assert result == answer


+@pytest.mark.qualitative
 def test_answer_relevance_classifier(backend):
     """Verify that the first phase of the answer relevance flow behaves as expectee."""
     context, answer, docs = _read_input_json("answer_relevance.json")
