diff --git a/dapr_agents/llm/__init__.py b/dapr_agents/llm/__init__.py
index 1f7613c0..90a3060b 100644
--- a/dapr_agents/llm/__init__.py
+++ b/dapr_agents/llm/__init__.py
@@ -9,4 +9,5 @@
 from .nvidia.client import NVIDIAClientBase
 from .nvidia.chat import NVIDIAChatClient
 from .nvidia.embeddings import NVIDIAEmbeddingClient
-from .elevenlabs import ElevenLabsSpeechClient
\ No newline at end of file
+from .elevenlabs import ElevenLabsSpeechClient
+from .dapr import DaprChatClient
\ No newline at end of file
diff --git a/dapr_agents/llm/dapr/__init__.py b/dapr_agents/llm/dapr/__init__.py
new file mode 100644
index 00000000..40b69cd1
--- /dev/null
+++ b/dapr_agents/llm/dapr/__init__.py
@@ -0,0 +1,2 @@
+from .chat import DaprChatClient
+from .client import DaprInferenceClientBase
\ No newline at end of file
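
The re-exports above make the new client importable from the package root as well as from the `dapr_agents.llm.dapr` submodule; a quick sanity check (not part of the diff):

```python
# Both import paths should resolve once this change lands.
from dapr_agents.llm import DaprChatClient
from dapr_agents.llm.dapr import DaprInferenceClientBase
```
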
+ """ + # Load the Prompty instance from the provided source + prompty_instance = Prompty.load(prompty_source) + + # Generate the prompt template from the Prompty instance + prompt_template = Prompty.to_prompt_template(prompty_instance) + + # Initialize the DaprChatClient based on the Prompty model configuration + return cls.model_validate({ + 'timeout': timeout, + 'prompty': prompty_instance, + 'prompt_template': prompt_template, + }) + + def translate_response(self, response: dict, model: str) -> dict: + """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion.""" + choices = [ + { + "finish_reason": "stop", + "index": i, + "message": { + "content": output["result"], + "role": "assistant" + }, + "logprobs": None + } + for i, output in enumerate(response.get("outputs", [])) + ] + + return { + "choices": choices, + "created": int(time.time()), + "model": model, + "object": "chat.completion", + "usage": {"total_tokens": "-1"} + } + + def convert_to_conversation_inputs(self, inputs: List[Dict[str, Any]]) -> List[ConversationInput]: + return [ + ConversationInput( + content=item["content"], + role=item.get("role"), + scrub_pii=item.get("scrubPII") == "true" + ) + for item in inputs + ] + + def generate( + self, + messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None, + input_data: Optional[Dict[str, Any]] = None, + llm_component: Optional[str] = None, + tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None, + response_model: Optional[Type[BaseModel]] = None, + scrubPII: Optional[bool] = False, + temperature: Optional[float] = None, + **kwargs + ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]: + """ + Generate chat completions based on provided messages or input_data for prompt templates. + + Args: + messages (Optional): Either pre-set messages or None if using input_data. + input_data (Optional[Dict[str, Any]]): Input variables for prompt templates. + llm_component (str): Name of the LLM component to use for the request. + tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request. + response_model (Type[BaseModel]): Optional Pydantic model for structured response parsing. + scrubPII (Type[bool]): Optional flag to obfuscate any sensitive information coming back from the LLM. + **kwargs: Additional parameters for the language model. + + Returns: + Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s). + """ + + # If input_data is provided, check for a prompt_template + if input_data: + if not self.prompt_template: + raise ValueError("Inputs are provided but no 'prompt_template' is set. 
diff --git a/dapr_agents/llm/dapr/client.py b/dapr_agents/llm/dapr/client.py
new file mode 100644
index 00000000..f0609f02
--- /dev/null
+++ b/dapr_agents/llm/dapr/client.py
@@ -0,0 +1,86 @@
+from dapr_agents.types.llm import DaprInferenceClientConfig
+from dapr_agents.llm.base import LLMClientBase
+from dapr.clients import DaprClient
+from dapr.clients.grpc._request import ConversationInput
+from dapr.clients.grpc._response import ConversationResponse
+from typing import Dict, Any, List
+from pydantic import model_validator
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+class DaprInferenceClient:
+    def translate_to_json(self, response: ConversationResponse) -> dict:
+        response_dict = {
+            "outputs": [
+                {
+                    "result": output.result,
+                }
+                for output in response.outputs
+            ]
+        }
+
+        return response_dict
+
+    def chat_completion(self, llm: str, conversation_inputs: List[ConversationInput], scrub_pii: bool | None = None, temperature: float | None = None) -> Any:
+        with DaprClient() as client:
+            response = client.converse_alpha1(name=llm, inputs=conversation_inputs, scrub_pii=scrub_pii, temperature=temperature)
+            output = self.translate_to_json(response)
+
+        return output
+
+
+class DaprInferenceClientBase(LLMClientBase):
+    """
+    Base class for managing Dapr Inference API clients.
+    Handles client initialization, configuration, and shared logic.
+    """
+    @model_validator(mode="before")
+    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        return values
+
+    def model_post_init(self, __context: Any) -> None:
+        """
+        Initializes private attributes after validation.
+        """
+        self._provider = "dapr"
+
+        # Set up the private config and client attributes
+        self._config = self.get_config()
+        self._client = self.get_client()
+        return super().model_post_init(__context)
+
+    def get_config(self) -> DaprInferenceClientConfig:
+        """
+        Returns the appropriate configuration for the Dapr Conversation API.
+        """
+        return DaprInferenceClientConfig()
+
+    def get_client(self) -> DaprInferenceClient:
+        """
+        Initializes and returns the Dapr Inference client.
+        """
+        return DaprInferenceClient()
+
+    @classmethod
+    def from_config(cls, client_options: DaprInferenceClientConfig, timeout: float = 1500):
+        """
+        Initializes the DaprInferenceClientBase using DaprInferenceClientConfig.
+
+        Args:
+            client_options: The configuration options for the client.
+            timeout: Timeout for requests (default is 1500 seconds).
+
+        Returns:
+            DaprInferenceClientBase: The initialized client instance.
+        """
+        return cls()
+
+    @property
+    def config(self) -> Dict[str, Any]:
+        return self._config
+
+    @property
+    def client(self) -> DaprInferenceClient:
+        return self._client
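
The low-level wrapper can also be driven directly, which is handy for testing the response translation in isolation. A sketch, assuming the same sidecar setup (the component name is illustrative; `ConversationInput`'s fields are taken from how the diff itself uses them):

```python
from dapr.clients.grpc._request import ConversationInput

from dapr_agents.llm.dapr.client import DaprInferenceClient

client = DaprInferenceClient()

# chat_completion opens a short-lived DaprClient per call and returns the
# response already translated into a plain dict of outputs.
result = client.chat_completion(
    llm="echo",  # hypothetical conversation component name
    conversation_inputs=[ConversationInput(content="Hello", role="user")],
    temperature=0.5,
)
print(result["outputs"][0]["result"])
```
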
+ """ + return DaprInferenceClientConfig() + + def get_client(self) -> DaprInferenceClient: + """ + Initializes and returns the Dapr Inference client. + """ + return DaprInferenceClient() + + @classmethod + def from_config(cls, client_options: DaprInferenceClientConfig, timeout: float = 1500): + """ + Initializes the DaprInferenceClientBase using DaprInferenceClientConfig. + + Args: + client_options: The configuration options for the client. + timeout: Timeout for requests (default is 1500 seconds). + + Returns: + DaprInferenceClientBase: The initialized client instance. + """ + return cls() + + @property + def config(self) -> Dict[str, Any]: + return self._config + + @property + def client(self) -> DaprInferenceClient: + return self._client diff --git a/dapr_agents/types/llm.py b/dapr_agents/types/llm.py index 770c6041..2a2fc5e1 100644 --- a/dapr_agents/types/llm.py +++ b/dapr_agents/types/llm.py @@ -31,6 +31,14 @@ def none_to_default(cls, v): raise PydanticUseDefault() return v +class DaprInferenceClientConfig: + @field_validator("*", mode="before") + @classmethod + def none_to_default(cls, v): + if v is None: + raise PydanticUseDefault() + return v + class HFInferenceClientConfig(BaseModel): model: Optional[str] = Field(None, description="Model ID on Hugging Face Hub or URL to a deployed Inference Endpoint. Defaults to a recommended model if not provided.") api_key: Optional[Union[str, bool]] = Field(None, description="Hugging Face API key for authentication. Defaults to the locally saved token. Pass False to skip token.") diff --git a/requirements.txt b/requirements.txt index af31705d..ee4f7fd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ openapi-schema-pydantic==1.2.4 regex>=2023.12.25 Jinja2==3.1.5 azure-identity==1.19.0 -dapr==1.14.0 +dapr==1.15.0rc3 dapr-ext-fastapi==1.14.0 dapr-ext-workflow==0.5.0 colorama==0.4.6