
Commit c83fe3b

Add Dapr's Conversation API as LLM inference provider (#8)
* add dapr conversation api as llm inference
* add response parsing
* replace http request with dapr client
* address review comments

Signed-off-by: yaron2 <[email protected]>
1 parent 6a08836 commit c83fe3b

File tree

6 files changed: +251 -2 lines changed


dapr_agents/llm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -9,4 +9,5 @@
 from .nvidia.client import NVIDIAClientBase
 from .nvidia.chat import NVIDIAChatClient
 from .nvidia.embeddings import NVIDIAEmbeddingClient
-from .elevenlabs import ElevenLabsSpeechClient
+from .elevenlabs import ElevenLabsSpeechClient
+from .dapr import DaprChatClient

dapr_agents/llm/dapr/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .chat import DaprChatClient
+from .client import DaprInferenceClientBase

dapr_agents/llm/dapr/chat.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+from dapr_agents.llm.dapr.client import DaprInferenceClientBase
+from dapr_agents.llm.utils import RequestHandler, ResponseHandler
+from dapr_agents.prompt.prompty import Prompty
+from dapr_agents.types.message import BaseMessage
+from dapr_agents.llm.chat import ChatClientBase
+from dapr_agents.tool import AgentTool
+from dapr.clients.grpc._request import ConversationInput
+from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type
+from pydantic import BaseModel
+from pathlib import Path
+import logging
+import os
+import time
+
+logger = logging.getLogger(__name__)
+
+class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
+    """
+    Concrete class for Dapr's chat completion API using the Inference API.
+    This class extends the ChatClientBase.
+    """
+
+    def model_post_init(self, __context: Any) -> None:
+        """
+        Initializes private attributes for provider, api, config, and client after validation.
+        """
+        # Set the private provider and api attributes
+        self._api = "chat"
+        self._llm_component = os.environ['DAPR_LLM_COMPONENT_DEFAULT']
+
+        return super().model_post_init(__context)
+
+    @classmethod
+    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'DaprChatClient':
+        """
+        Initializes a DaprChatClient using a Prompty source, which can be a file path or inline content.
+
+        Args:
+            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
+                or inline Prompty content as a string.
+            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
+
+        Returns:
+            DaprChatClient: An instance of DaprChatClient configured with the model settings from the Prompty source.
+        """
+        # Load the Prompty instance from the provided source
+        prompty_instance = Prompty.load(prompty_source)
+
+        # Generate the prompt template from the Prompty instance
+        prompt_template = Prompty.to_prompt_template(prompty_instance)
+
+        # Initialize the DaprChatClient based on the Prompty model configuration
+        return cls.model_validate({
+            'timeout': timeout,
+            'prompty': prompty_instance,
+            'prompt_template': prompt_template,
+        })
+
+    def translate_response(self, response: dict, model: str) -> dict:
+        """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""
+        choices = [
+            {
+                "finish_reason": "stop",
+                "index": i,
+                "message": {
+                    "content": output["result"],
+                    "role": "assistant"
+                },
+                "logprobs": None
+            }
+            for i, output in enumerate(response.get("outputs", []))
+        ]
+
+        return {
+            "choices": choices,
+            "created": int(time.time()),
+            "model": model,
+            "object": "chat.completion",
+            "usage": {"total_tokens": "-1"}
+        }
+
+    def convert_to_conversation_inputs(self, inputs: List[Dict[str, Any]]) -> List[ConversationInput]:
+        return [
+            ConversationInput(
+                content=item["content"],
+                role=item.get("role"),
+                scrub_pii=item.get("scrubPII") == "true"
+            )
+            for item in inputs
+        ]
+
+    def generate(
+        self,
+        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
+        input_data: Optional[Dict[str, Any]] = None,
+        llm_component: Optional[str] = None,
+        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
+        response_model: Optional[Type[BaseModel]] = None,
+        scrubPII: Optional[bool] = False,
+        temperature: Optional[float] = None,
+        **kwargs
+    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
+        """
+        Generate chat completions based on provided messages or input_data for prompt templates.
+
+        Args:
+            messages (Optional): Either pre-set messages or None if using input_data.
+            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
+            llm_component (str): Name of the LLM component to use for the request.
+            tools (List[Union[AgentTool, Dict[str, Any]]]): List of tools for the request.
+            response_model (Type[BaseModel]): Optional Pydantic model for structured response parsing.
+            scrubPII (bool, optional): Flag to obfuscate any sensitive information coming back from the LLM.
+            **kwargs: Additional parameters for the language model.
+
+        Returns:
+            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
+        """
+
+        # If input_data is provided, check for a prompt_template
+        if input_data:
+            if not self.prompt_template:
+                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")
+
+            logger.info("Using prompt template to generate messages.")
+            messages = self.prompt_template.format_prompt(**input_data)
+
+        # Ensure we have messages at this point
+        if not messages:
+            raise ValueError("Either 'messages' or 'input_data' must be provided.")
+
+        # Process and normalize the messages
+        params = {'inputs': RequestHandler.normalize_chat_messages(messages)}
+        # Merge Prompty parameters if available, then override with any explicit kwargs
+        if self.prompty:
+            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
+        else:
+            params.update(kwargs)
+
+        # Prepare and send the request
+        params = RequestHandler.process_params(params, llm_provider=self.provider, tools=tools, response_model=response_model)
+        inputs = self.convert_to_conversation_inputs(params['inputs'])
+
+        try:
+            logger.info("Invoking the Dapr Conversation API.")
+            response = self.client.chat_completion(llm=llm_component or self._llm_component, conversation_inputs=inputs, scrub_pii=scrubPII, temperature=temperature)
+            transposed_response = self.translate_response(response, self._llm_component)
+            logger.info("Chat completion retrieved successfully.")
+
+            return ResponseHandler.process_response(transposed_response, llm_provider=self.provider, response_model=response_model, stream=params.get('stream', False))
+        except Exception as e:
+            logger.error(f"An error occurred during the Dapr Conversation API call: {e}")
+            raise
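
A minimal usage sketch of the client added above, assuming a running Dapr sidecar with a conversation component configured; the component name "echo" and the prompt are illustrative, not part of this commit:

    import os

    from dapr_agents.llm import DaprChatClient

    # DaprChatClient reads its default component from this variable
    # (see model_post_init above); "echo" is a hypothetical component name.
    os.environ["DAPR_LLM_COMPONENT_DEFAULT"] = "echo"

    llm = DaprChatClient()

    # Plain dicts are accepted; messages are normalized and then converted
    # to ConversationInput objects before the sidecar call.
    response = llm.generate(messages=[{"role": "user", "content": "What is Dapr?"}])
    print(response)

Note that translate_response reshapes the Conversation API output into an OpenAI-style chat.completion dict, so the shared ResponseHandler path works unchanged downstream.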

dapr_agents/llm/dapr/client.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+from dapr_agents.types.llm import DaprInferenceClientConfig
+from dapr_agents.llm.base import LLMClientBase
+from dapr.clients import DaprClient
+from dapr.clients.grpc._request import ConversationInput
+from dapr.clients.grpc._response import ConversationResponse
+from typing import Dict, Any, List
+from pydantic import model_validator
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+class DaprInferenceClient:
+    def translate_to_json(self, response: ConversationResponse) -> dict:
+        response_dict = {
+            "outputs": [
+                {
+                    "result": output.result,
+                }
+                for output in response.outputs
+            ]
+        }
+
+        return response_dict
+
+    def chat_completion(self, llm: str, conversation_inputs: List[ConversationInput], scrub_pii: bool | None = None, temperature: float | None = None) -> Any:
+        with DaprClient() as client:
+            response = client.converse_alpha1(name=llm, inputs=conversation_inputs, scrub_pii=scrub_pii, temperature=temperature)
+            output = self.translate_to_json(response)
+
+        return output
+
+
+class DaprInferenceClientBase(LLMClientBase):
+    """
+    Base class for managing Dapr Inference API clients.
+    Handles client initialization, configuration, and shared logic.
+    """
+    @model_validator(mode="before")
+    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        return values
+
+    def model_post_init(self, __context: Any) -> None:
+        """
+        Initializes private attributes after validation.
+        """
+        self._provider = "dapr"
+
+        # Set up the private config and client attributes
+        self._config = self.get_config()
+        self._client = self.get_client()
+        return super().model_post_init(__context)
+
+    def get_config(self) -> DaprInferenceClientConfig:
+        """
+        Returns the appropriate configuration for the Dapr Conversation API.
+        """
+        return DaprInferenceClientConfig()
+
+    def get_client(self) -> DaprInferenceClient:
+        """
+        Initializes and returns the Dapr Inference client.
+        """
+        return DaprInferenceClient()
+
+    @classmethod
+    def from_config(cls, client_options: DaprInferenceClientConfig, timeout: float = 1500):
+        """
+        Initializes the DaprInferenceClientBase using DaprInferenceClientConfig.
+
+        Args:
+            client_options: The configuration options for the client.
+            timeout: Timeout for requests (default is 1500 seconds).
+
+        Returns:
+            DaprInferenceClientBase: The initialized client instance.
+        """
+        return cls()
+
+    @property
+    def config(self) -> Dict[str, Any]:
+        return self._config
+
+    @property
+    def client(self) -> DaprInferenceClient:
+        return self._client
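
The low-level client can also be driven directly; a sketch under the same sidecar assumption, with "echo" again a hypothetical component name:

    from dapr.clients.grpc._request import ConversationInput

    from dapr_agents.llm.dapr.client import DaprInferenceClient

    client = DaprInferenceClient()
    inputs = [ConversationInput(content="Summarize Dapr in one sentence.", role="user", scrub_pii=False)]

    # chat_completion wraps DaprClient.converse_alpha1 and returns the
    # translated dict, e.g. {"outputs": [{"result": "..."}]}.
    result = client.chat_completion(llm="echo", conversation_inputs=inputs)
    print(result["outputs"][0]["result"])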

dapr_agents/types/llm.py

Lines changed: 8 additions & 0 deletions
@@ -31,6 +31,14 @@ def none_to_default(cls, v):
             raise PydanticUseDefault()
         return v
 
+class DaprInferenceClientConfig:
+    @field_validator("*", mode="before")
+    @classmethod
+    def none_to_default(cls, v):
+        if v is None:
+            raise PydanticUseDefault()
+        return v
+
 class HFInferenceClientConfig(BaseModel):
     model: Optional[str] = Field(None, description="Model ID on Hugging Face Hub or URL to a deployed Inference Endpoint. Defaults to a recommended model if not provided.")
     api_key: Optional[Union[str, bool]] = Field(None, description="Hugging Face API key for authentication. Defaults to the locally saved token. Pass False to skip token.")

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ openapi-schema-pydantic==1.2.4
 regex>=2023.12.25
 Jinja2==3.1.5
 azure-identity==1.19.0
-dapr==1.14.0
+dapr==1.15.0rc3
 dapr-ext-fastapi==1.14.0
 dapr-ext-workflow==0.5.0
 colorama==0.4.6
