diff --git a/patchwork/common/client/llm/anthropic.py b/patchwork/common/client/llm/anthropic.py
index 081445c67..0bc7f98ae 100644
--- a/patchwork/common/client/llm/anthropic.py
+++ b/patchwork/common/client/llm/anthropic.py
@@ -128,22 +128,22 @@ def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam])
         return new_messages
 
     def __adapt_chat_completion_request(
-            self,
-            messages: Iterable[ChatCompletionMessageParam],
-            model: str,
-            frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-            logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-            max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-            n: Optional[int] | NotGiven = NOT_GIVEN,
-            presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-            stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-            temperature: Optional[float] | NotGiven = NOT_GIVEN,
-            tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-            tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
-            top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ):
         system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
         adapted_messages = self.__adapt_input_messages(messages)
@@ -171,6 +171,9 @@ def __adapt_chat_completion_request(
         elif tool_choice_type == "none":
             tool_choice = NOT_GIVEN
 
+        anthropic_tools = NOT_GIVEN
+        if tools is not None and tools is not NOT_GIVEN:
+            anthropic_tools = [tool.get("function") for tool in tools if tool.get("function") is not None]
         input_kwargs = dict(
             messages=adapted_messages,
             system=system,
@@ -178,7 +181,7 @@ def __adapt_chat_completion_request(
             model=model,
             stop_sequences=[stop] if isinstance(stop, str) else stop,
             temperature=temperature,
-            tools=[tool.get("function") for tool in tools if tool.get("function") is not None],
+            tools=anthropic_tools,
             tool_choice=tool_choice,
             top_p=top_p,
         )
@@ -204,22 +207,22 @@ def is_model_supported(self, model: str) -> bool:
         return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
 
     def is_prompt_supported(
-            self,
-            messages: Iterable[ChatCompletionMessageParam],
-            model: str,
-            frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-            logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-            max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-            n: Optional[int] | NotGiven = NOT_GIVEN,
-            presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-            stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-            temperature: Optional[float] | NotGiven = NOT_GIVEN,
-            tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-            tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
-            top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> int:
         model_limit = self.__get_model_limit(model)
         input_kwargs = self.__adapt_chat_completion_request(
@@ -248,27 +251,27 @@ def is_prompt_supported(
         return model_limit - message_token_count.input_tokens
 
     def truncate_messages(
-            self, messages: Iterable[ChatCompletionMessageParam], model: str
+        self, messages: Iterable[ChatCompletionMessageParam], model: str
     ) -> Iterable[ChatCompletionMessageParam]:
         return self._truncate_messages(self, messages, model)
 
     def chat_completion(
-            self,
-            messages: Iterable[ChatCompletionMessageParam],
-            model: str,
-            frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-            logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-            max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-            n: Optional[int] | NotGiven = NOT_GIVEN,
-            presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-            stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-            temperature: Optional[float] | NotGiven = NOT_GIVEN,
-            tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-            tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
-            top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = self.__adapt_chat_completion_request(
             messages=messages,
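The substantive change in anthropic.py is the new `anthropic_tools` guard: previously the list comprehension iterated `tools` unconditionally when building `input_kwargs`, so a request that omitted tools (leaving the parameter as the OpenAI-style `NOT_GIVEN` sentinel) raised a `TypeError` because the sentinel is not iterable. A minimal sketch of the fixed behavior, using a stand-in sentinel and a hypothetical `adapt_tools` helper rather than the client's real classes:

```python
class _NotGiven:
    """Stand-in for the SDK's NOT_GIVEN sentinel (assumption: a plain,
    non-iterable singleton object)."""

NOT_GIVEN = _NotGiven()


def adapt_tools(tools):
    # Mirrors the guard added in __adapt_chat_completion_request: only
    # iterate when the caller actually supplied a tools list.
    anthropic_tools = NOT_GIVEN
    if tools is not None and tools is not NOT_GIVEN:
        # Keep only entries that carry a "function" payload.
        anthropic_tools = [t.get("function") for t in tools if t.get("function") is not None]
    return anthropic_tools


# Omitted tools now pass straight through as the sentinel...
assert adapt_tools(NOT_GIVEN) is NOT_GIVEN
# ...while supplied tools are still reduced to their "function" payloads.
assert adapt_tools([{"function": {"name": "lookup"}}, {}]) == [{"name": "lookup"}]
```

Passing the sentinel through unchanged presumably lets the downstream Anthropic client omit the `tools` field entirely, rather than receiving an empty or invalid list.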
diff --git a/pyproject.toml b/pyproject.toml
index 32eebbbc6..740a88437 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "patchwork-cli"
-version = "0.0.80"
+version = "0.0.81"
 description = ""
 authors = ["patched.codes"]
 license = "AGPL"