7
7
ChatCompletionAssistantMessageParam ,
8
8
ChatCompletionContentPartParam ,
9
9
ChatCompletionMessageParam ,
10
+ ChatCompletionMessageToolCallParam ,
10
11
ChatCompletionNamedToolChoiceParam ,
11
12
ChatCompletionRole ,
12
13
ChatCompletionSystemMessageParam ,
14
+ ChatCompletionToolMessageParam ,
13
15
ChatCompletionToolParam ,
14
16
ChatCompletionUserMessageParam ,
15
17
)
16
18
17
19
from .model_helper import count_tokens_for_message , count_tokens_for_system_and_tools , get_token_limit
18
20
19
21
20
- def normalize_content (content : Union [str , Iterable [ChatCompletionContentPartParam ]]):
22
+ def normalize_content (content : Union [str , Iterable [ChatCompletionContentPartParam ], None ]):
23
+ if content is None :
24
+ return None
21
25
if isinstance (content , str ):
22
26
return unicodedata .normalize ("NFC" , content )
23
27
else :
@@ -48,7 +52,12 @@ def all_messages(self) -> list[ChatCompletionMessageParam]:
48
52
return [self .system_message ] + self .messages
49
53
50
54
def insert_message (
51
- self , role : ChatCompletionRole , content : Union [str , Iterable [ChatCompletionContentPartParam ]], index : int = 0
55
+ self ,
56
+ role : ChatCompletionRole ,
57
+ content : Union [str , Iterable [ChatCompletionContentPartParam ], None ],
58
+ index : int = 0 ,
59
+ tool_calls : Optional [Iterable [ChatCompletionMessageToolCallParam ]] = None ,
60
+ tool_call_id : Optional [str ] = None ,
52
61
):
53
62
"""
54
63
Inserts a message into the conversation at the specified index,
@@ -63,8 +72,14 @@ def insert_message(
63
72
message = ChatCompletionUserMessageParam (role = "user" , content = normalize_content (content ))
64
73
elif role == "assistant" and isinstance (content , str ):
65
74
message = ChatCompletionAssistantMessageParam (role = "assistant" , content = normalize_content (content ))
75
+ elif role == "assistant" and tool_calls is not None :
76
+ message = ChatCompletionAssistantMessageParam (role = "assistant" , tool_calls = tool_calls )
77
+ elif role == "tool" and tool_call_id is not None :
78
+ message = ChatCompletionToolMessageParam (
79
+ role = "tool" , tool_call_id = tool_call_id , content = normalize_content (content )
80
+ )
66
81
else :
67
- raise ValueError (f "Invalid role: { role } " )
82
+ raise ValueError ("Invalid message for builder " )
68
83
self .messages .insert (index , message )
69
84
70
85
@@ -102,9 +117,17 @@ def build_messages(
102
117
message_builder = _MessageBuilder (system_prompt )
103
118
104
119
for shot in reversed (few_shots ):
105
- if shot ["role" ] is None or shot ["content" ] is None :
106
- raise ValueError ("Few-shot messages must have both role and content" )
107
- message_builder .insert_message (shot ["role" ], shot ["content" ])
120
+ if shot ["role" ] is None or (shot .get ("content" ) is None and shot .get ("tool_calls" ) is None ):
121
+ raise ValueError ("Few-shot messages must have role and either content or tool_calls" )
122
+ tool_call_id = shot .get ("tool_call_id" )
123
+ if tool_call_id is not None and not isinstance (tool_call_id , str ):
124
+ raise ValueError ("tool_call_id must be a string value" )
125
+ tool_calls = shot .get ("tool_calls" )
126
+ if tool_calls is not None and not isinstance (tool_calls , Iterable ):
127
+ raise ValueError ("tool_calls must be a list of tool calls" )
128
+ message_builder .insert_message (
129
+ shot ["role" ], shot .get ("content" ), tool_calls = tool_calls , tool_call_id = tool_call_id
130
+ )
108
131
109
132
append_index = len (few_shots )
110
133
0 commit comments