def _generate(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None,
    stop: Optional[list[str]] = None,
    stream: bool = True,
    user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
    """
    Invoke llm completion model

    :param model: model name
    :param credentials: credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
    # Base headers; provider-specific extras from credentials are merged on top
    # so they can override the defaults.
    headers = {
        "Content-Type": "application/json",
        "Accept-Charset": "utf-8",
    }
    extra_headers = credentials.get("extra_headers")
    if extra_headers is not None:
        headers = {**headers, **extra_headers}

    api_key = credentials.get("api_key")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Normalize the base URL: a trailing slash is required for urljoin to
    # append the endpoint path instead of replacing the last path segment.
    endpoint_url = credentials["endpoint_url"]
    if not endpoint_url.endswith("/"):
        endpoint_url += "/"

    # Translate the flat response_format parameter into the structured
    # OpenAI-style payload ({"type": ...} / {"type": "json_schema", ...}).
    response_format = model_parameters.get("response_format")
    if response_format:
        if response_format == "json_schema":
            json_schema = model_parameters.get("json_schema")
            if not json_schema:
                raise ValueError("Must define JSON Schema when the response format is json_schema")
            try:
                # Validate that json_schema is a JSON object before sending it.
                schema = TypeAdapter(dict[str, Any]).validate_json(json_schema)
            except Exception as exc:
                raise ValueError(f"not correct json_schema format: {json_schema}") from exc
            model_parameters.pop("json_schema")
            model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
        else:
            model_parameters["response_format"] = {"type": response_format}
    elif "json_schema" in model_parameters:
        # A schema without a response_format selection is meaningless; drop it.
        del model_parameters["json_schema"]

    data = {"model": model, "stream": stream, **model_parameters}

    # request usage data in streaming mode for token counting
    if stream:
        data["stream_options"] = {"include_usage": True}

    completion_type = LLMMode.value_of(credentials["mode"])

    if completion_type is LLMMode.CHAT:
        endpoint_url = urljoin(endpoint_url, "chat/completions")
        data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
    elif completion_type is LLMMode.COMPLETION:
        endpoint_url = urljoin(endpoint_url, "completions")
        # Completion mode sends a single prompt string, not a message list.
        data["prompt"] = prompt_messages[0].content
    else:
        raise ValueError("Unsupported completion type for model configuration.")

    # annotate tools with names, descriptions, etc.
    function_calling_type = credentials.get("function_calling_type", "no_call")
    formatted_tools = []
    if tools:
        if function_calling_type == "function_call":
            # Legacy OpenAI "functions" field.
            data["functions"] = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                }
                for tool in tools
            ]
        elif function_calling_type == "tool_call":
            # Modern "tools" field with automatic tool selection.
            data["tool_choice"] = "auto"
            for tool in tools:
                formatted_tools.append(PromptMessageFunction(function=tool).model_dump())
            data["tools"] = formatted_tools

    if stop:
        data["stop"] = stop

    if user:
        data["user"] = user

    response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)

    # requests defaults to ISO-8859-1 when the server omits a charset;
    # force UTF-8 so non-ASCII model output decodes correctly.
    if response.encoding is None or response.encoding == "ISO-8859-1":
        response.encoding = "utf-8"

    if response.status_code != 200:
        raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

    if stream:
        return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

    return self._handle_generate_response(model, credentials, response, prompt_messages)