diff --git a/patchwork/common/client/llm/openai_.py b/patchwork/common/client/llm/openai_.py
index 6573ee24e..604664a92 100644
--- a/patchwork/common/client/llm/openai_.py
+++ b/patchwork/common/client/llm/openai_.py
@@ -33,6 +33,7 @@ class OpenAiLlmClient(LlmClient):
         "gpt-3.5-turbo": 16_385,
         "gpt-4": 8_192,
         "gpt-4-turbo": 8_192,
+        "o1-preview": 128_000,
         "o1-mini": 128_000,
         "gpt-4o-mini": 128_000,
         "gpt-4o": 128_000,
@@ -137,4 +138,82 @@ def chat_completion(
             top_p=top_p,
         )
 
+        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in [
+            "json_object",
+            "json_schema",
+        ]
+        if model.startswith("o1") and is_json_output_required:
+            return self.__o1_chat_completion(**input_kwargs)
         return self.client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))
+
+    def __o1_chat_completion(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ):
+        o1_messages = list(messages)
+        if response_format.get("type") == "json_schema":
+            last_msg_idx = len(o1_messages) - 1
+            last_msg = o1_messages[last_msg_idx]
+            last_msg["content"] = (
+                last_msg["content"]
+                + f"""
+Respond with the following json schema in mind:
+{response_format.get('json_schema')}
+"""
+            )
+
+        o1_input_kwargs = dict(
+            messages=o1_messages,
+            model=model,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            stop=stop,
+            temperature=temperature,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+        )
+
+        o1_response = self.client.chat.completions.create(**NotGiven.remove_not_given(o1_input_kwargs))
+
+        o1_choices_parser_responses = []
+        for o1_choice in o1_response.choices:
+            parser_input_kwargs = dict(
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}",
+                    }
+                ],
+                model="gpt-4o-mini",
+                max_tokens=max_tokens,
+                n=1,
+                response_format=response_format,
+            )
+            parser_response = self.client.beta.chat.completions.parse(**NotGiven.remove_not_given(parser_input_kwargs))
+            o1_choices_parser_responses.append(parser_response)
+
+        reconstructed_response = o1_response.model_copy()
+        for i, o1_choices_parser_response in enumerate(o1_choices_parser_responses):
+            if reconstructed_response.usage is not None:
+                reconstructed_response.usage.completion_tokens += o1_choices_parser_response.usage.completion_tokens
+                reconstructed_response.usage.prompt_tokens += o1_choices_parser_response.usage.prompt_tokens
+                reconstructed_response.usage.total_tokens += o1_choices_parser_response.usage.total_tokens
+            reconstructed_response.choices[i].message.content = o1_choices_parser_response.choices[0].message.content
+
+        return reconstructed_response
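
For context, a minimal usage sketch of the new path. Only OpenAiLlmClient, chat_completion, and the response_format handling come from the patch above; the constructor argument (api_key), the import path derived from the file path, and the example schema payload are assumptions for illustration, not part of the change.

    # Sketch: exercising the o1 JSON fallback added in this patch.
    # Assumed: OpenAiLlmClient takes an api_key; schema payload shown is illustrative.
    from patchwork.common.client.llm.openai_ import OpenAiLlmClient

    client = OpenAiLlmClient(api_key="sk-...")  # assumed constructor signature

    # With an o1 model and a json_schema response_format, chat_completion routes through
    # __o1_chat_completion: the schema is appended to the last message, the o1 output is
    # then re-parsed into the schema by gpt-4o-mini, and usage counts are merged.
    response = client.chat_completion(
        messages=[{"role": "user", "content": "Summarise this change as JSON."}],
        model="o1-preview",
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "summary",
                "schema": {
                    "type": "object",
                    "properties": {"summary": {"type": "string"}},
                    "required": ["summary"],
                },
            },
        },
    )
    print(response.choices[0].message.content)  # JSON text conforming to the schema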