Skip to content

Commit 717fb74

Browse files
authored
feat: update base/openai chat model module (#70)
* feat: update base/openai chat model module * feat: update base/openai chat model module * feat: update base/openai chat model module * feat: update base/openai chat model module
1 parent 3474982 commit 717fb74

File tree

3 files changed: +53 −18 lines changed

openjudge/models/base_chat_model.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
from openjudge.models.schema.oai.response import ChatResponse
1212

13-
TOOL_CHOICE_MODES = ["auto", "none", "any", "required"]
13+
TOOL_CHOICE_MODES = {"auto", "none", "any", "required"}
1414

1515

1616
class BaseChatModel(ABC):
@@ -113,13 +113,34 @@ def _validate_tool_choice(
113113
raise TypeError(
114114
f"tool_choice must be str, got {type(tool_choice)}",
115115
)
116+
117+
tool_choice = tool_choice.strip()
118+
if not tool_choice:
119+
raise ValueError("`tool_choice` must be a non-empty string.")
120+
116121
if tool_choice in TOOL_CHOICE_MODES:
117122
return
118123

119-
available_functions = [tool["function"]["name"] for tool in tools] if tools else []
124+
if not tools:
125+
raise ValueError(
126+
f"Tool choice '{tool_choice}' is not a built-in mode ({', '.join(TOOL_CHOICE_MODES)}) "
127+
"and no tools were provided."
128+
)
129+
130+
available_functions = set()
131+
for i, tool in enumerate(tools):
132+
if not isinstance(tool, dict):
133+
raise TypeError(f"Tool at index {i} is not a dictionary.")
134+
func = tool.get("function")
135+
if not isinstance(func, dict):
136+
raise TypeError(f"Tool at index {i} missing or invalid 'function' field.")
137+
name = func.get("name")
138+
if not isinstance(name, str):
139+
raise TypeError(f"Tool function name at index {i} is not a string.")
140+
available_functions.add(name)
120141

121142
if tool_choice not in available_functions:
122-
all_options = TOOL_CHOICE_MODES + available_functions
143+
all_options = sorted(TOOL_CHOICE_MODES | available_functions)
123144
raise ValueError(
124-
f"Invalid tool_choice '{tool_choice}'. " f"Available options: {', '.join(sorted(all_options))}",
145+
f"Invalid tool_choice '{tool_choice}'. " f"Available options: {', '.join(all_options)}",
125146
)

openjudge/models/openai_chat_model.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""OpenAI Client."""
3+
import copy
34
import os
45
from typing import Any, AsyncGenerator, Callable, Dict, Literal, Type
56

@@ -13,28 +14,36 @@
1314
from openjudge.utils.utils import repair_and_load_json
1415

1516

16-
def _format_audio_data_for_qwen_omni(messages: list[dict | ChatMessage]) -> None:
17+
def _format_audio_data_for_qwen_omni(messages: list[dict | ChatMessage]) -> list[dict]:
1718
"""Qwen-omni uses OpenAI-compatible API but requires different audio
1819
data format than OpenAI with "data:;base64," prefix.
1920
Refer to `Qwen-omni documentation
20-
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`_
21+
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`
2122
for more details.
2223
2324
Args:
2425
messages (`list[dict]`):
2526
The list of message dictionaries from OpenAI formatter.
2627
"""
28+
format_data = []
2729
for msg in messages:
28-
msg_dict = msg.to_dict() if isinstance(msg, ChatMessage) else msg
29-
if isinstance(msg_dict.get("content"), list):
30-
for block in msg_dict["content"]:
31-
if (
32-
isinstance(block, dict)
33-
and "input_audio" in block
34-
and isinstance(block["input_audio"].get("data"), str)
35-
):
36-
if not block["input_audio"]["data"].startswith("http"):
30+
try:
31+
msg_copy = copy.deepcopy(msg)
32+
msg_dict = msg_copy.to_dict() if isinstance(msg_copy, ChatMessage) else msg_copy
33+
if isinstance(msg_dict.get("content"), list):
34+
for block in msg_dict["content"]:
35+
if (
36+
isinstance(block, dict)
37+
and "input_audio" in block
38+
and isinstance(block["input_audio"].get("data"), str)
39+
and not block["input_audio"]["data"].startswith("http")
40+
):
3741
block["input_audio"]["data"] = "data:;base64," + block["input_audio"]["data"]
42+
format_data.append(msg_dict)
43+
except Exception as e:
44+
logger.error(f"Failed to format audio data: {type(e).__name__}: {e}", exc_info=True)
45+
format_data.append(msg.to_dict() if isinstance(msg, ChatMessage) else msg)
46+
return format_data
3847

3948

4049
class OpenAIChatModel(BaseChatModel):
@@ -150,7 +159,7 @@ async def achat(
150159

151160
# Qwen-omni requires different base64 audio format from openai
152161
if "omni" in self.model.lower():
153-
_format_audio_data_for_qwen_omni(messages)
162+
messages = _format_audio_data_for_qwen_omni(messages)
154163

155164
kwargs = {
156165
"model": self.model,
@@ -188,9 +197,14 @@ async def achat(
188197

189198
# Use simple json_object format for models that don't support complex JSON schema
190199
if "qwen" in self.model.lower() or "gemini" in self.model.lower():
191-
structured_model = {"type": "json_object"} # type: ignore
200+
logger.warning(
201+
"Qwen models do not support Pydantic structured output via `response_format`. "
202+
"Update the unstructured JSON mode with `response_format={'type': 'json_object'}`."
203+
)
204+
structured_model = {"type": "json_object"}
192205

193206
kwargs["response_format"] = structured_model
207+
194208
if not self.stream:
195209
response = await self.client.chat.completions.parse(**kwargs)
196210
else:

tests/models/test_openai_chat_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def test_qwen_omni_audio_formatting(self):
311311
]
312312

313313
# Apply the transformation
314-
_format_audio_data_for_qwen_omni(messages)
314+
messages = _format_audio_data_for_qwen_omni(messages)
315315

316316
# Check that the data was formatted correctly
317317
assert messages[0]["content"][0]["input_audio"]["data"].startswith(

0 commit comments

Comments
 (0)