|
7 | 7 | ChatCompletionAssistantMessageParam,
|
8 | 8 | ChatCompletionContentPartParam,
|
9 | 9 | ChatCompletionMessageParam,
|
| 10 | + ChatCompletionMessageToolCallParam, |
10 | 11 | ChatCompletionNamedToolChoiceParam,
|
11 | 12 | ChatCompletionRole,
|
12 | 13 | ChatCompletionSystemMessageParam,
|
|
18 | 19 | from .model_helper import count_tokens_for_message, count_tokens_for_system_and_tools, get_token_limit
|
19 | 20 |
|
20 | 21 |
|
21 |
| -def normalize_content(content: Union[str, Iterable[ChatCompletionContentPartParam]]): |
| 22 | +def normalize_content(content: Union[str, Iterable[ChatCompletionContentPartParam], None]): |
| 23 | + if content is None: |
| 24 | + return None |
22 | 25 | if isinstance(content, str):
|
23 | 26 | return unicodedata.normalize("NFC", content)
|
24 | 27 | else:
|
@@ -51,9 +54,9 @@ def all_messages(self) -> list[ChatCompletionMessageParam]:
|
51 | 54 | def insert_message(
|
52 | 55 | self,
|
53 | 56 | role: ChatCompletionRole,
|
54 |
| - content: Union[str, Iterable[ChatCompletionContentPartParam]], |
| 57 | + content: Union[str, Iterable[ChatCompletionContentPartParam], None], |
55 | 58 | index: int = 0,
|
56 |
| - tool_calls: Optional[list[ChatCompletionToolParam]] = None, |
| 59 | + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] = None, |
57 | 60 | tool_call_id: Optional[str] = None,
|
58 | 61 | ):
|
59 | 62 | """
|
@@ -116,8 +119,14 @@ def build_messages(
|
116 | 119 | for shot in reversed(few_shots):
|
117 | 120 | if shot["role"] is None or (shot.get("content") is None and shot.get("tool_calls") is None):
|
118 | 121 | raise ValueError("Few-shot messages must have role and either content or tool_calls")
|
| 122 | + tool_call_id = shot.get("tool_call_id") |
| 123 | + if tool_call_id is not None and not isinstance(tool_call_id, str): |
| 124 | + raise ValueError("tool_call_id must be a string value") |
| 125 | + tool_calls = shot.get("tool_calls") |
| 126 | + if tool_calls is not None and not isinstance(tool_calls, Iterable): |
| 127 | + raise ValueError("tool_calls must be a list of tool calls") |
119 | 128 | message_builder.insert_message(
|
120 |
| - shot["role"], shot.get("content"), tool_calls=shot.get("tool_calls"), tool_call_id=shot.get("tool_call_id") |
| 129 | + shot["role"], shot.get("content"), tool_calls=tool_calls, tool_call_id=tool_call_id |
121 | 130 | )
|
122 | 131 |
|
123 | 132 | append_index = len(few_shots)
|
|
0 commit comments