Add Dapr's Conversation API as LLM inference provider #8
Merged
Changes from 2 of 4 commits
The first file is the new package `__init__.py` (presumably `dapr_agents/llm/dapr/__init__.py`, given the import paths used below), which re-exports the two new classes (2 added lines):

```python
from .chat import DaprChatClient
from .client import DaprInferenceClientBase
```
The chat client itself, imported as `dapr_agents.llm.dapr.chat` (new file, 144 added lines):

```python
from dapr_agents.llm.dapr.client import DaprInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type
from pydantic import BaseModel
from pathlib import Path
import logging
import os
import time

logger = logging.getLogger(__name__)


class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
    """
    Concrete chat-completion client backed by Dapr's Conversation API.
    This class extends ChatClientBase.
    """

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes private attributes for provider, api, config, and client after validation.
        """
        # Set the private api attribute and the default LLM component.
        # DAPR_LLM_COMPONENT_DEFAULT is required; a missing variable raises KeyError.
        self._api = "chat"
        self._llm_component = os.environ['DAPR_LLM_COMPONENT_DEFAULT']

        return super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'DaprChatClient':
        """
        Initializes a DaprChatClient using a Prompty source, which can be a file path or inline content.

        Args:
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
                or inline Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.

        Returns:
            DaprChatClient: An instance of DaprChatClient configured with the settings from the Prompty source.
        """
        # Load the Prompty instance from the provided source
        prompty_instance = Prompty.load(prompty_source)

        # Generate the prompt template from the Prompty instance
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Extract the model configuration from Prompty (currently unused)
        model_config = prompty_instance.model

        # Initialize the DaprChatClient based on the Prompty configuration
        return cls.model_validate({
            'timeout': timeout,
            'prompty': prompty_instance,
            'prompt_template': prompt_template,
        })

    def translate_response(self, response: dict, model: str) -> dict:
        """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""
        choices = [
            {
                "finish_reason": "stop",
                "index": i,
                "message": {
                    "content": output["result"],
                    "role": "assistant"
                },
                "logprobs": None
            }
            for i, output in enumerate(response.get("outputs", []))
        ]

        return {
            "choices": choices,
            "created": int(time.time()),
            "model": model,
            "object": "chat.completion",
            "usage": {"total_tokens": len(response.get("outputs", []))}
        }

    def generate(
        self,
        messages: Optional[Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]]] = None,
        input_data: Optional[Dict[str, Any]] = None,
        llm_component: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_model: Optional[Type[BaseModel]] = None,
        scrubPII: Optional[bool] = False,
        **kwargs
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

        Args:
            messages (Optional): Pre-set messages, or None when using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
            llm_component (Optional[str]): Name of the LLM component to use for the request.
            tools (Optional[List[Union[AgentTool, Dict[str, Any]]]]): List of tools for the request.
            response_model (Optional[Type[BaseModel]]): Pydantic model for structured response parsing.
            scrubPII (Optional[bool]): Flag to obfuscate any sensitive information coming back from the LLM.
            **kwargs: Additional parameters for the language model.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
        """
        # If input_data is provided, a prompt_template is required
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

        # Ensure we have messages at this point
        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {'inputs': RequestHandler.normalize_chat_messages(messages)}
        # Merge Prompty parameters if available, then override with any explicit kwargs
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
        else:
            params.update(kwargs)

        params['scrubPII'] = scrubPII

        # Prepare and send the request
        params = RequestHandler.process_params(params, llm_provider=self.provider, tools=tools, response_model=response_model)

        try:
            logger.info("Invoking the Dapr Conversation API.")
            response = self.client.chat_completion(llm_component or self._llm_component, params)
            translated_response = self.translate_response(response, self._llm_component)
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(translated_response, llm_provider=self.provider, response_model=response_model, stream=params.get('stream', False))
        except Exception as e:
            logger.error(f"An error occurred during the Dapr Conversation API call: {e}")
            raise
```
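For orientation, a minimal usage sketch (not part of the PR; the component names `openai` and `echo` are hypothetical, and `DAPR_LLM_COMPONENT_DEFAULT` must be set before construction because `model_post_init` reads it with `os.environ[...]`):

```python
import os

# DaprChatClient reads its default Conversation API component name from this
# environment variable at construction time (KeyError if unset).
os.environ["DAPR_LLM_COMPONENT_DEFAULT"] = "openai"  # hypothetical component name

from dapr_agents.llm.dapr import DaprChatClient

llm = DaprChatClient()

# A plain string is accepted; RequestHandler.normalize_chat_messages turns it
# into the Conversation API's "inputs" list.
reply = llm.generate(messages="What is Dapr?")

# Or target a specific Conversation API component for a single call and ask
# Dapr to scrub sensitive data from the output.
reply = llm.generate(
    messages=[{"role": "user", "content": "Summarize this support ticket."}],
    llm_component="echo",  # hypothetical component name
    scrubPII=True,
)
```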
Finally, the base client and the thin HTTP wrapper around the Conversation endpoint, imported as `dapr_agents.llm.dapr.client` (new file, 79 added lines):

```python
from dapr_agents.types.llm import DaprInferenceClientConfig
from dapr_agents.llm.base import LLMClientBase
from typing import Dict, Any
from pydantic import model_validator
import os
import logging
import requests
import json

logger = logging.getLogger(__name__)


class DaprClient:
    """Thin HTTP client for the Dapr Conversation API (v1.0-alpha1)."""

    def __init__(self):
        self._dapr_endpoint = os.getenv('DAPR_BASE_URL', 'http://localhost') + ':' + os.getenv(
            'DAPR_HTTP_PORT', '3500')

    def chat_completion(self, llm: str, request: Dict[str, Any]) -> Any:
        # Invoke the named conversation component through the Dapr sidecar
        result = requests.post(
            url='%s/v1.0-alpha1/conversation/%s/converse' % (self._dapr_endpoint, llm),
            data=json.dumps(request)
        )

        return result.json()


class DaprInferenceClientBase(LLMClientBase):
    """
    Base class for managing Dapr Inference API clients.
    Handles client initialization, configuration, and shared logic.
    """
    @model_validator(mode="before")
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # No pre-validation adjustments are needed yet
        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes private attributes after validation.
        """
        self._provider = "dapr"

        # Set up the private config and client attributes
        self._config = self.get_config()
        self._client = self.get_client()
        return super().model_post_init(__context)

    def get_config(self) -> DaprInferenceClientConfig:
        """
        Returns the appropriate configuration for the Dapr Conversation API.
        """
        return DaprInferenceClientConfig()

    def get_client(self) -> DaprClient:
        """
        Initializes and returns the Dapr Inference client.
        """
        config: DaprInferenceClientConfig = self.config  # currently unused

        return DaprClient()

    @classmethod
    def from_config(cls, client_options: DaprInferenceClientConfig, timeout: float = 1500):
        """
        Initializes the DaprInferenceClientBase using DaprInferenceClientConfig.

        Args:
            client_options: The configuration options for the client.
            timeout: Timeout for requests (default is 1500 seconds).

        Returns:
            DaprInferenceClientBase: The initialized client instance.
        """
        # client_options and timeout are accepted for interface parity but not yet used
        return cls()

    @property
    def config(self) -> DaprInferenceClientConfig:
        return self._config

    @property
    def client(self) -> DaprClient:
        return self._client
```
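To make the wire contract concrete, here is an illustrative sketch of the request and response bodies exchanged with the sidecar, inferred from `generate` and `translate_response` above (field values are made up; the authoritative alpha1 schema is defined by Dapr itself, not by this PR):

```python
# Request body POSTed to /v1.0-alpha1/conversation/<component>/converse,
# as assembled by DaprChatClient.generate:
request = {
    "inputs": [
        {"content": "What is Dapr?", "role": "user"},  # normalized chat messages
    ],
    "scrubPII": False,
    # ...plus any Prompty parameters and explicit **kwargs merged by generate()
}

# Response body returned by the sidecar; translate_response maps each
# outputs[i]["result"] onto an OpenAI-style "choices" entry:
response = {
    "outputs": [
        {"result": "Dapr is a portable runtime for building distributed apps."},
    ],
}
```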
Review comment: Not used
Reply: Fixed