 from functools import lru_cache
 
 from anthropic import Anthropic
-from anthropic.types import Message, TextBlockParam
+from anthropic.types import Message, MessageParam, TextBlockParam
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessage,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
     completion_create_params,
 )
 from openai.types.chat.chat_completion import Choice, CompletionUsage
@@ -81,23 +83,169 @@ def __get_model_limit(self, model: str) -> int:
             return 100_000 - safety_margin
         return 200_000 - safety_margin
 
+    def __adapt_input_messages(
+        self, messages: Iterable[ChatCompletionMessageParam]
+    ) -> tuple[Iterable[TextBlockParam] | NotGiven, list[MessageParam]]:
+        # Split OpenAI-style messages into Anthropic's top-level system
+        # blocks and a list of MessageParam dicts, translating tool results
+        # and assistant tool calls along the way.
+        system: Iterable[TextBlockParam] | NotGiven = NOT_GIVEN
+        new_messages = []
+        for message in messages:
+            if message.get("role") == "system":
+                # Anthropic takes the system prompt as a separate top-level
+                # parameter rather than as a conversation message.
+                if system is NOT_GIVEN:
+                    system = list()
+                system.append(TextBlockParam(text=message.get("content"), type="text"))
+            elif message.get("role") == "tool":
+                # An OpenAI tool result becomes a user message carrying a
+                # tool_result content block.
+                new_messages.append(
+                    dict(
+                        role="user",
+                        content=[
+                            dict(
+                                type="tool_result",
+                                tool_use_id=message.get("tool_call_id"),
+                                content=message.get("content"),
+                            )
+                        ],
+                    )
+                )
+            elif message.get("role") == "assistant" and len(message.get("tool_calls", [])) > 0:
+                # Assistant tool calls become tool_use content blocks; OpenAI
+                # serializes arguments as a JSON string, Anthropic expects a
+                # parsed object.
+                tool_calls = message["tool_calls"]
+                tool_calls_as_content = [
+                    dict(
+                        type="tool_use",
+                        id=tool_call["id"],
+                        name=tool_call["function"]["name"],
+                        input=json.loads(tool_call["function"]["arguments"]),
+                    )
+                    for tool_call in tool_calls
+                ]
+                new_messages.append(
+                    dict(
+                        role="assistant",
+                        content=[
+                            *tool_calls_as_content,
+                        ],
+                    )
+                )
+            else:
+                new_messages.append(message)
+
+        return system, new_messages
+
+    def __adapt_chat_completion_request(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ):
+        system, adapted_messages = self.__adapt_input_messages(messages)
+        default_max_token = 1000
+
+        if tool_choice is not NOT_GIVEN:
+            # OpenAI tool_choice to Anthropic tool_choice mapping:
+            # OpenAI   : none, auto, required, {"type": "function", ...}
+            # Anthropic: omit, auto, any,      {"type": "tool", "name": ...}
+            if isinstance(tool_choice, str):
+                if tool_choice == "required":
+                    tool_choice = dict(type="any")
+                elif tool_choice == "none":
+                    tool_choice = NOT_GIVEN
+                else:
+                    tool_choice = dict(type=tool_choice)
+            else:
+                tool_choice_type = tool_choice.get("type")
+                if tool_choice_type in ("function", "required"):
+                    if tool_choice.get("function") is not None:
+                        # Build a new dict so the caller's object is not
+                        # mutated and no stray "function" key reaches Anthropic.
+                        tool_choice = dict(type="tool", name=tool_choice["function"]["name"])
+                    else:
+                        tool_choice = dict(type="any")
+                elif tool_choice_type == "none":
+                    tool_choice = NOT_GIVEN
173+
+        input_kwargs = dict(
+            messages=adapted_messages,
+            system=system,
+            max_tokens=default_max_token if max_tokens is None or max_tokens is NOT_GIVEN else max_tokens,
+            model=model,
+            stop_sequences=[stop] if isinstance(stop, str) else stop,
+            temperature=temperature,
+            tools=NOT_GIVEN if tools is NOT_GIVEN else [
+                # OpenAI nests a "function" spec with a "parameters" schema;
+                # Anthropic expects flat tool dicts keyed by "input_schema".
+                dict(
+                    name=tool["function"]["name"],
+                    description=tool["function"].get("description", ""),
+                    input_schema=tool["function"]["parameters"],
+                )
+                for tool in tools
+                if tool.get("function") is not None
+            ],
+            tool_choice=tool_choice,
+            top_p=top_p,
+        )
185+
+        if response_format is not NOT_GIVEN and response_format.get("type") == "json_schema":
+            # Emulate OpenAI's JSON-schema response format by forcing a call
+            # to a synthetic tool whose input schema is the requested schema.
+            input_kwargs["tool_choice"] = dict(type="tool", name="response_format")
+            if input_kwargs.get("tools") is NOT_GIVEN:
+                input_kwargs["tools"] = list()
+            response_format_tool = dict(
+                name="response_format",
+                description="The response format to use",
+                input_schema=response_format["json_schema"]["schema"],
+            )
+            input_kwargs["tools"] = [*input_kwargs["tools"], response_format_tool]
+
+        return NotGiven.remove_not_given(input_kwargs)
+
     @lru_cache(maxsize=None)
     def get_models(self) -> set[str]:
         return self.__definitely_allowed_models.union({f"{self.__allowed_model_prefix}*"})
 
     def is_model_supported(self, model: str) -> bool:
         return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
 
-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+    def is_prompt_supported(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ) -> int:
         model_limit = self.__get_model_limit(model)
-        token_count = 0
-        for message in messages:
-            message_token_count = self.client.count_tokens(message.get("content"))
-            token_count = token_count + message_token_count
-            if token_count > model_limit:
-                return -1
-
-        return model_limit - token_count
+        input_kwargs = self.__adapt_chat_completion_request(
+            messages=messages,
+            model=model,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            stop=stop,
+            temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+        )
+        # The token-counting endpoint accepts only a subset of the
+        # messages.create kwargs, so filter before calling it.
+        count_token_input_kwargs = {
+            k: v
+            for k, v in input_kwargs.items()
+            if k in {"messages", "model", "system", "tool_choice", "tools", "beta"}
+        }
+        message_token_count = self.client.beta.messages.count_tokens(**count_token_input_kwargs)
+        return model_limit - message_token_count.input_tokens
 
     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
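
With this change, `is_prompt_supported` reports the remaining token budget via the server-side token-counting endpoint instead of summing per-message counts, and it goes negative when the prompt is over the limit. A minimal usage sketch, with `adapter` as a hypothetical instance of this class and an illustrative model id:

messages = [{"role": "user", "content": "Summarize our Q3 report."}]
model = "claude-3-5-sonnet-latest"  # illustrative model id

remaining = adapter.is_prompt_supported(messages=messages, model=model)
if remaining < 0:
    # The prompt exceeds the model limit minus the safety margin,
    # so truncate before sending the request.
    messages = adapter.truncate_messages(messages, model)
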
@@ -117,38 +265,28 @@ def chat_completion(
         response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
-        system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
-        other_messages = []
-        for message in messages:
-            if message.get("role") == "system":
-                if system is NOT_GIVEN:
-                    system = list()
-                system.append(TextBlockParam(text=message.get("content"), type="text"))
-            else:
-                other_messages.append(message)
-
-        default_max_token = 1000
-        input_kwargs = dict(
-            messages=other_messages,
-            system=system,
-            max_tokens=default_max_token if max_tokens is None or max_tokens is NOT_GIVEN else max_tokens,
+        input_kwargs = self.__adapt_chat_completion_request(
+            messages=messages,
             model=model,
-            stop_sequences=[stop] if isinstance(stop, str) else stop,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            stop=stop,
             temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
             top_p=top_p,
         )
-        if response_format is not NOT_GIVEN and response_format.get("type") == "json_schema":
-            input_kwargs["tool_choice"] = dict(type="tool", name="response_format")
-            input_kwargs["tools"] = [
-                dict(
-                    name="response_format",
-                    description="The response format to use",
-                    input_schema=response_format["json_schema"]["schema"],
-                )
-            ]
 
-        response = self.client.messages.create(**NotGiven.remove_not_given(input_kwargs))
+        response = self.client.messages.create(**input_kwargs)
         return _anthropic_to_openai_response(model, response)
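
Taken together, the adapter now accepts the OpenAI tool-calling surface and serves it with an Anthropic model. A sketch of the intended round trip, assuming `adapter` is an instance of this class and that `_anthropic_to_openai_response` (not shown in this diff) maps `tool_use` blocks back to OpenAI-style `tool_calls`; the `get_weather` tool is made up for illustration:

import json

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

completion = adapter.chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    model="claude-3-5-sonnet-latest",  # illustrative model id
    tools=tools,
    tool_choice="auto",  # mapped to Anthropic's {"type": "auto"}
)
for tool_call in completion.choices[0].message.tool_calls or []:
    print(tool_call.function.name, json.loads(tool_call.function.arguments))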