Skip to content

Commit d2677d3

Browse files
authored
Gordonlim tlm chat typing fix (#107)
1 parent af852d0 commit d2677d3

File tree

5 files changed

+68
-23
lines changed

5 files changed

+68
-23
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ extra-dependencies = [
4848
"pytest-asyncio",
4949
"python-dotenv",
5050
"tiktoken",
51-
"openai<=1.99.1",
51+
"openai",
5252
]
5353
[tool.hatch.envs.types.scripts]
5454
check = "mypy --strict --install-types --non-interactive {args:src/cleanlab_tlm tests}"
@@ -60,7 +60,7 @@ extra-dependencies = [
6060
"python-dotenv",
6161
"pytest-asyncio",
6262
"tiktoken",
63-
"openai<=1.99.1",
63+
"openai",
6464
]
6565

6666
[tool.hatch.envs.hatch-test.env-vars]

src/cleanlab_tlm/utils/chat.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ def _form_prompt_chat_completions_api(
339339

340340
# Only return content directly if there's a single user message AND no tools
341341
if len(messages) == 1 and messages[0].get("role") == _USER_ROLE and (tools is None or len(tools) == 0):
342-
return output + str(messages[0]["content"])
342+
first_msg = cast(dict[str, Any], messages[0])
343+
return output + str(first_msg["content"])
343344

344345
# Warn if the last message is an assistant message with tool calls
345346
if messages and (messages[-1].get("role") == _ASSISTANT_ROLE or "tool_calls" in messages[-1]):
@@ -358,22 +359,26 @@ def _form_prompt_chat_completions_api(
358359
if msg["role"] == _ASSISTANT_ROLE:
359360
output += _ASSISTANT_PREFIX
360361
# Handle content if present
361-
if msg.get("content"):
362-
output += f"{msg['content']}\n\n"
362+
content_value = cast(Optional[str], msg.get("content"))
363+
if content_value:
364+
output += f"{content_value}\n\n"
363365
# Handle tool calls if present
364366
if "tool_calls" in msg:
365367
for tool_call in msg["tool_calls"]:
366-
call_id = tool_call["id"]
367-
function_names[call_id] = tool_call["function"]["name"]
368-
# Format function call as JSON within XML tags, now including call_id
369-
function_call = {
370-
"name": tool_call["function"]["name"],
371-
"arguments": json.loads(tool_call["function"]["arguments"])
372-
if tool_call["function"]["arguments"]
373-
else {},
374-
"call_id": call_id,
375-
}
376-
output += f"{_TOOL_CALL_TAG_START}\n{json.dumps(function_call, indent=2)}\n{_TOOL_CALL_TAG_END}\n\n"
368+
if tool_call["type"] == "function":
369+
call_id = tool_call["id"]
370+
function_names[call_id] = tool_call["function"]["name"]
371+
# Format function call as JSON within XML tags, now including call_id
372+
function_call = {
373+
"name": tool_call["function"]["name"],
374+
"arguments": json.loads(tool_call["function"]["arguments"])
375+
if tool_call["function"]["arguments"]
376+
else {},
377+
"call_id": call_id,
378+
}
379+
output += (
380+
f"{_TOOL_CALL_TAG_START}\n{json.dumps(function_call, indent=2)}\n{_TOOL_CALL_TAG_END}\n\n"
381+
)
377382
elif msg["role"] == _TOOL_ROLE:
378383
# Handle tool responses
379384
output += _TOOL_PREFIX
@@ -506,13 +511,19 @@ def form_response_string_chat_completions_api(
506511
"""
507512
response_dict = _response_to_dict(response)
508513
content = response_dict.get("content") or ""
509-
tool_calls = response_dict.get("tool_calls")
514+
tool_calls = cast(Optional[list[dict[str, Any]]], response_dict.get("tool_calls"))
510515
if tool_calls is not None:
511516
try:
512-
tool_calls_str = "\n".join(
513-
f"{_TOOL_CALL_TAG_START}\n{json.dumps({'name': call['function']['name'], 'arguments': json.loads(call['function']['arguments']) if call['function']['arguments'] else {}}, indent=2)}\n{_TOOL_CALL_TAG_END}"
514-
for call in tool_calls
515-
)
517+
rendered_calls: list[str] = []
518+
for call in tool_calls:
519+
function_dict = call["function"]
520+
name = cast(str, function_dict["name"])
521+
args_str = cast(Optional[str], function_dict.get("arguments"))
522+
args_obj = json.loads(args_str) if args_str else {}
523+
rendered_calls.append(
524+
f"{_TOOL_CALL_TAG_START}\n{json.dumps({'name': name, 'arguments': args_obj}, indent=2)}\n{_TOOL_CALL_TAG_END}"
525+
)
526+
tool_calls_str = "\n".join(rendered_calls)
516527
return f"{content}\n{tool_calls_str}".strip() if content else tool_calls_str
517528
except (KeyError, TypeError, json.JSONDecodeError) as e:
518529
# Log the error but continue with just the content

tests/openai_compat.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
"""Compatibility layer for OpenAI SDK tool call types across versions.

This module exposes two names for tests to import:
- ChatCompletionMessageToolCall: the class representing a function tool call.
  In previous versions, this was called `ChatCompletionMessageToolCall`,
  but that has now become a union for both Custom Tool Calls and Function Tool Calls,
  which have different schemas.
  This is a shim to allow tests to work with both types.
- Function: the inner function payload model.
  The Function model is only used for Function Tool Calls, which are currently supported
  by this package.

Works with SDKs that expose either the legacy
`chat_completion_message_tool_call.ChatCompletionMessageToolCall` or the
newer `chat_completion_message_function_tool_call.ChatCompletionMessageFunctionToolCall`.
"""

from __future__ import annotations

try:  # OpenAI SDK >= 1.99.2: function tool calls live in a dedicated module.
    from openai.types.chat.chat_completion_message_function_tool_call import (
        ChatCompletionMessageFunctionToolCall as ChatCompletionMessageToolCall,
    )
    from openai.types.chat.chat_completion_message_function_tool_call import (
        Function,
    )
# Catch only ImportError: a broad `except Exception` would also swallow
# unrelated failures raised while importing the SDK (e.g. an AttributeError
# inside openai itself) and mask them behind a confusing legacy-path error.
except ImportError:  # OpenAI SDK <= 1.99.1: fall back to the legacy module.
    # importlib.import_module keeps static analyzers from flagging this
    # version-dependent import path as unresolved on newer SDKs.
    import importlib

    _legacy = importlib.import_module("openai.types.chat.chat_completion_message_tool_call")
    ChatCompletionMessageToolCall = _legacy.ChatCompletionMessageToolCall  # type: ignore
    Function = _legacy.Function  # type: ignore

__all__ = ["ChatCompletionMessageToolCall", "Function"]

tests/test_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import pytest
44
from openai.types.chat import ChatCompletion, ChatCompletionMessage
55
from openai.types.chat.chat_completion import Choice
6-
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
76

87
from cleanlab_tlm.utils.chat import (
98
_form_prompt_chat_completions_api,
@@ -12,6 +11,7 @@
1211
form_response_string_chat_completions,
1312
form_response_string_chat_completions_api,
1413
)
14+
from tests.openai_compat import ChatCompletionMessageToolCall, Function
1515

1616
if TYPE_CHECKING:
1717
from openai.types.chat import ChatCompletionMessageParam

tests/test_chat_completions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55
from openai.types.chat import ChatCompletion, ChatCompletionMessage
66
from openai.types.chat.chat_completion import Choice
7-
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
87
from openai.types.completion_usage import (
98
CompletionTokensDetails,
109
CompletionUsage,
@@ -17,6 +16,7 @@
1716
from cleanlab_tlm.utils.chat_completions import TLMChatCompletion
1817
from tests.conftest import make_text_unique
1918
from tests.constants import TEST_PROMPT, TEST_RESPONSE
19+
from tests.openai_compat import ChatCompletionMessageToolCall, Function
2020
from tests.test_get_trustworthiness_score import is_trustworthiness_score_json_format
2121

2222
test_prompt = make_text_unique(TEST_PROMPT)

0 commit comments

Comments
 (0)