Merged
Changes from 2 commits
3 changes: 2 additions & 1 deletion dapr_agents/llm/__init__.py
@@ -9,4 +9,5 @@
from .nvidia.client import NVIDIAClientBase
from .nvidia.chat import NVIDIAChatClient
from .nvidia.embeddings import NVIDIAEmbeddingClient
from .elevenlabs import ElevenLabsSpeechClient
from .dapr import DaprChatClient
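
With this export in place, the client can be imported from the package root. A minimal usage sketch (the component name 'echo' and the environment setup are illustrative assumptions, and a running Dapr sidecar is required):

import os

# The client reads its default conversation component from this variable (see chat.py below).
os.environ.setdefault("DAPR_LLM_COMPONENT_DEFAULT", "echo")  # hypothetical component name

from dapr_agents.llm import DaprChatClient

client = DaprChatClient()
print(client.generate(messages="What is Dapr?"))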
2 changes: 2 additions & 0 deletions dapr_agents/llm/dapr/__init__.py
@@ -0,0 +1,2 @@
from .chat import DaprChatClient
from .client import DaprInferenceClientBase
144 changes: 144 additions & 0 deletions dapr_agents/llm/dapr/chat.py
@@ -0,0 +1,144 @@
from dapr_agents.llm.dapr.client import DaprInferenceClientBase
from dapr_agents.llm.utils import RequestHandler, ResponseHandler
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type
from pydantic import BaseModel
from pathlib import Path
import logging
import os
import time

logger = logging.getLogger(__name__)

class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
    """
    Concrete class for Dapr's chat completion API, backed by the Dapr Conversation API.
    This class extends ChatClientBase.
    """

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes private attributes for provider, api, config, and client after validation.
        """
        # Set the private api attribute and the default conversation component,
        # which must be named by the DAPR_LLM_COMPONENT_DEFAULT environment variable.
        self._api = "chat"
        self._llm_component = os.environ['DAPR_LLM_COMPONENT_DEFAULT']

        return super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'DaprChatClient':
        """
        Initializes a DaprChatClient using a Prompty source, which can be a file path or inline content.

        Args:
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
                or inline Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.

        Returns:
            DaprChatClient: An instance of DaprChatClient configured with the model settings from the Prompty source.
        """
        # Load the Prompty instance from the provided source
        prompty_instance = Prompty.load(prompty_source)

        # Generate the prompt template from the Prompty instance
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Extract the model configuration from Prompty
        model_config = prompty_instance.model
Contributor: Not used

Member Author: Fixed
        # Initialize the DaprChatClient based on the Prompty model configuration
        return cls.model_validate({
            'timeout': timeout,
            'prompty': prompty_instance,
            'prompt_template': prompt_template,
        })

    def translate_response(self, response: dict, model: str) -> dict:
        """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""
        choices = [
            {
                "finish_reason": "stop",
                "index": i,
                "message": {
                    "content": output["result"],
                    "role": "assistant"
                },
                "logprobs": None
            }
            for i, output in enumerate(response.get("outputs", []))
        ]

        return {
            "choices": choices,
            "created": int(time.time()),
            "model": model,
            "object": "chat.completion",
            # NOTE: token counts are not returned here; the number of outputs is used as a placeholder.
            "usage": {"total_tokens": len(response.get("outputs", []))}
        }
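
An aside, not part of the diff: a quick illustration of the mapping above, with the Dapr payload shape taken from this method and the values purely illustrative.

sample = {"outputs": [{"result": "Hello!"}]}
translated = DaprChatClient().translate_response(sample, model="echo")  # 'echo' is a hypothetical component
assert translated["choices"][0]["message"]["content"] == "Hello!"
assert translated["object"] == "chat.completion"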

    def generate(
        self,
        messages: Optional[Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]]] = None,
        input_data: Optional[Dict[str, Any]] = None,
        llm_component: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_model: Optional[Type[BaseModel]] = None,
        scrubPII: Optional[bool] = False,
        **kwargs
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

        Args:
            messages (Optional): Either pre-set messages or None if using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
            llm_component (Optional[str]): Name of the LLM component to use for the request; defaults to the
                component named by the DAPR_LLM_COMPONENT_DEFAULT environment variable.
            tools (Optional[List[Union[AgentTool, Dict[str, Any]]]]): List of tools for the request.
            response_model (Optional[Type[BaseModel]]): Pydantic model for structured response parsing.
            scrubPII (Optional[bool]): Flag to obfuscate any sensitive information coming back from the LLM.
            **kwargs: Additional parameters for the language model.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
        """
        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

        # Ensure we have messages at this point
        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {'inputs': RequestHandler.normalize_chat_messages(messages)}
        # Merge Prompty parameters if available, then override with any explicit kwargs
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
        else:
            params.update(kwargs)

        params['scrubPII'] = scrubPII

        # Prepare and send the request
        params = RequestHandler.process_params(params, llm_provider=self.provider, tools=tools, response_model=response_model)

        try:
            logger.info("Invoking the Dapr Conversation API.")
            # Use the per-call component when provided, falling back to the configured default,
            # and report the same component name in the translated response.
            component = llm_component or self._llm_component
            response = self.client.chat_completion(component, params)
            transposed_response = self.translate_response(response, component)
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(transposed_response, llm_provider=self.provider, response_model=response_model, stream=params.get('stream', False))
        except Exception as e:
            logger.error(f"An error occurred during the Dapr Conversation API call: {e}")
            raise
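
An end-to-end usage sketch for the client above. The Prompty file name, template variable, and component name are illustrative assumptions; a running Dapr sidecar with a registered conversation component is required.

from dapr_agents.llm.dapr import DaprChatClient

client = DaprChatClient()

# Direct messages, overriding the default component from the environment.
result = client.generate(messages="Summarize Dapr in one sentence.", llm_component="echo")

# Prompt-template path: input_data is rendered through the Prompty template first.
templated = DaprChatClient.from_prompty("summarize.prompty")  # hypothetical file
result = templated.generate(input_data={"topic": "Dapr"})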
79 changes: 79 additions & 0 deletions dapr_agents/llm/dapr/client.py
@@ -0,0 +1,79 @@
from dapr_agents.types.llm import DaprInferenceClientConfig
from dapr_agents.llm.base import LLMClientBase
from typing import Optional, Dict, Any, List
from pydantic import Field, model_validator
import os
import logging
import requests
import json

logger = logging.getLogger(__name__)

class DaprClient:
    def __init__(self):
        # Resolve the sidecar endpoint from the environment, defaulting to a local sidecar.
        self._dapr_endpoint = os.getenv('DAPR_BASE_URL', 'http://localhost') + ':' + os.getenv('DAPR_HTTP_PORT', '3500')

    def chat_completion(self, llm: str, request: Dict[str, Any]) -> Any:
        # Invoke the Dapr Conversation API (alpha1) on the named component.
        result = requests.post(
            url=f"{self._dapr_endpoint}/v1.0-alpha1/conversation/{llm}/converse",
            data=json.dumps(request),
            headers={"Content-Type": "application/json"},
        )
        # Surface HTTP errors instead of trying to parse an error body as a completion.
        result.raise_for_status()
        return result.json()

class DaprInferenceClientBase(LLMClientBase):
    """
    Base class for managing Dapr Inference API clients.
    Handles client initialization, configuration, and shared logic.
    """

    @model_validator(mode="before")
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Placeholder for pre-validation logic; values currently pass through unchanged.
        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes private attributes after validation.
        """
        self._provider = "dapr"

        # Set up the private config and client attributes
        self._config = self.get_config()
        self._client = self.get_client()
        return super().model_post_init(__context)

    def get_config(self) -> DaprInferenceClientConfig:
        """
        Returns the appropriate configuration for the Dapr Conversation API.
        """
        return DaprInferenceClientConfig()

    def get_client(self) -> DaprClient:
        """
        Initializes and returns the Dapr Inference client.
        """
        config: DaprInferenceClientConfig = self.config
Contributor: The config var doesn't seem to be used.

Member Author: Fixed
        return DaprClient()

    @classmethod
    def from_config(cls, client_options: DaprInferenceClientConfig, timeout: float = 1500):
        """
        Initializes the DaprInferenceClientBase using DaprInferenceClientConfig.

        Args:
            client_options: The configuration options for the client.
            timeout: Timeout for requests (default is 1500 seconds).

        Returns:
            DaprInferenceClientBase: The initialized client instance.
        """
        # The Dapr client currently needs no explicit configuration; the arguments are accepted for API symmetry.
        return cls()

    @property
    def config(self) -> DaprInferenceClientConfig:
        return self._config

    @property
    def client(self) -> DaprClient:
        return self._client
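
For reference, the raw sidecar call that DaprClient.chat_completion issues; the URL and serialization are taken from the code above, while the component name and payload are illustrative (the exact request schema is produced by RequestHandler in chat.py).

import json
import requests

payload = {"inputs": [{"role": "user", "content": "What is Dapr?"}]}  # assumed message shape
resp = requests.post(
    "http://localhost:3500/v1.0-alpha1/conversation/echo/converse",  # 'echo' is a hypothetical component
    data=json.dumps(payload),
    headers={"Content-Type": "application/json"},
)
print(resp.json())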
8 changes: 8 additions & 0 deletions dapr_agents/types/llm.py
@@ -31,6 +31,14 @@ def none_to_default(cls, v):
            raise PydanticUseDefault()
        return v

class DaprInferenceClientConfig(BaseModel):
    @field_validator("*", mode="before")
    @classmethod
    def none_to_default(cls, v):
        if v is None:
            raise PydanticUseDefault()
        return v
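
The none_to_default validator mirrors the other config classes in this module: an explicit None falls back to the field default instead of overriding it. A minimal sketch of the pattern (illustrative model, not part of the diff):

from typing import Optional
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

class ExampleConfig(BaseModel):
    base_url: Optional[str] = Field("http://localhost:3500", description="Sidecar endpoint.")

    @field_validator("*", mode="before")
    @classmethod
    def none_to_default(cls, v):
        # An explicit None falls back to the field default instead of clearing it.
        if v is None:
            raise PydanticUseDefault()
        return v

assert ExampleConfig(base_url=None).base_url == "http://localhost:3500"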

class HFInferenceClientConfig(BaseModel):
    model: Optional[str] = Field(None, description="Model ID on Hugging Face Hub or URL to a deployed Inference Endpoint. Defaults to a recommended model if not provided.")
    api_key: Optional[Union[str, bool]] = Field(None, description="Hugging Face API key for authentication. Defaults to the locally saved token. Pass False to skip token.")