1+ from dapr_agents .llm .dapr .client import DaprInferenceClientBase
2+ from dapr_agents .llm .utils import RequestHandler , ResponseHandler
3+ from dapr_agents .prompt .prompty import Prompty
4+ from dapr_agents .types .message import BaseMessage
5+ from dapr_agents .llm .chat import ChatClientBase
6+ from dapr_agents .tool import AgentTool
7+ from dapr .clients .grpc ._request import ConversationInput
8+ from typing import Union , Optional , Iterable , Dict , Any , List , Iterator , Type
9+ from pydantic import BaseModel
10+ from pathlib import Path
11+ import logging
12+ import os
13+ import time
14+
# Module-level logger; inherits configuration from the host application.
logger = logging.getLogger(__name__)
16+
class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
    """
    Concrete class for Dapr's chat completion API using the Inference API.
    This class extends the ChatClientBase.
    """

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes private attributes for provider, api, config, and client after validation.

        Raises:
            KeyError: If the ``DAPR_LLM_COMPONENT_DEFAULT`` environment variable is not
                set. Failing fast here surfaces the misconfiguration at construction
                time instead of on the first request.
        """
        # Set the private provider and api attributes
        self._api = "chat"
        self._llm_component = os.environ["DAPR_LLM_COMPONENT_DEFAULT"]

        return super().model_post_init(__context)

    @classmethod
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "DaprChatClient":
        """
        Initializes a DaprChatClient client using a Prompty source, which can be a file path or inline content.

        Args:
            prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
                or inline Prompty content as a string.
            timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.

        Returns:
            DaprChatClient: An instance of DaprChatClient configured with the model settings from the Prompty source.
        """
        # Load the Prompty instance from the provided source
        prompty_instance = Prompty.load(prompty_source)

        # Generate the prompt template from the Prompty instance
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Initialize the DaprChatClient based on the Prompty model configuration
        return cls.model_validate({
            "timeout": timeout,
            "prompty": prompty_instance,
            "prompt_template": prompt_template,
        })

    def translate_response(self, response: dict, model: str) -> dict:
        """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion.

        Args:
            response (dict): Raw Dapr Conversation API response; each entry in its
                ``outputs`` list is expected to carry a ``result`` string.
            model (str): Model/component name to report in the translated payload.

        Returns:
            dict: An OpenAI-style ``chat.completion`` dict with one choice per output.
        """
        choices = [
            {
                # Dapr does not report a finish reason, so "stop" is assumed.
                "finish_reason": "stop",
                "index": i,
                "message": {
                    "content": output["result"],
                    "role": "assistant",
                },
                "logprobs": None,
            }
            for i, output in enumerate(response.get("outputs", []))
        ]

        return {
            "choices": choices,
            "created": int(time.time()),
            "model": model,
            "object": "chat.completion",
            # Dapr does not expose token usage; "-1" (string) is the sentinel the
            # downstream response schema expects — confirm before changing the type.
            "usage": {"total_tokens": "-1"},
        }

    def convert_to_conversation_inputs(self, inputs: List[Dict[str, Any]]) -> List[ConversationInput]:
        """Map normalized chat messages to Dapr ``ConversationInput`` objects.

        The per-message ``scrubPII`` flag is honored when it is either the boolean
        ``True`` or the string ``"true"`` (previously only the string form was
        recognized, so a boolean flag was silently dropped).

        Args:
            inputs (List[Dict[str, Any]]): Normalized chat messages.

        Returns:
            List[ConversationInput]: One input per message.
        """
        return [
            ConversationInput(
                content=item["content"],
                role=item.get("role"),
                scrub_pii=item.get("scrubPII") in (True, "true"),
            )
            for item in inputs
        ]

    def generate(
        self,
        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
        input_data: Optional[Dict[str, Any]] = None,
        llm_component: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_model: Optional[Type[BaseModel]] = None,
        scrubPII: Optional[bool] = False,
        temperature: Optional[float] = None,
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

        Args:
            messages (Optional): Either pre-set messages or None if using input_data.
            input_data (Optional[Dict[str, Any]]): Input variables for prompt templates.
            llm_component (Optional[str]): Name of the LLM component to use for the request;
                falls back to the DAPR_LLM_COMPONENT_DEFAULT component when omitted.
            tools (Optional[List[Union[AgentTool, Dict[str, Any]]]]): List of tools for the request.
            response_model (Optional[Type[BaseModel]]): Optional Pydantic model for structured response parsing.
            scrubPII (Optional[bool]): Optional flag to obfuscate any sensitive information coming back from the LLM.
            temperature (Optional[float]): Sampling temperature forwarded to the component.
            **kwargs: Additional parameters for the language model.

        Returns:
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).

        Raises:
            ValueError: If neither ``messages`` nor ``input_data`` is provided, or if
                ``input_data`` is given without a configured ``prompt_template``.
        """
        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

        # Ensure we have messages at this point
        if not messages:
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {"inputs": RequestHandler.normalize_chat_messages(messages)}
        # Merge Prompty parameters if available, then override with any explicit kwargs
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
        else:
            params.update(kwargs)

        # Prepare and send the request
        params = RequestHandler.process_params(params, llm_provider=self.provider, tools=tools, response_model=response_model)
        inputs = self.convert_to_conversation_inputs(params["inputs"])

        # Resolve the component once so the request and the translated response agree.
        # (Previously the response was always labeled with the default component,
        # even when an explicit llm_component override was used.)
        component = llm_component or self._llm_component

        try:
            logger.info("Invoking the Dapr Conversation API.")
            response = self.client.chat_completion(
                llm=component,
                conversation_inputs=inputs,
                scrub_pii=scrubPII,
                temperature=temperature,
            )
            transposed_response = self.translate_response(response, component)
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(
                transposed_response,
                llm_provider=self.provider,
                response_model=response_model,
                stream=params.get("stream", False),
            )
        except Exception as e:
            logger.error(f"An error occurred during the Dapr Conversation API call: {e}")
            raise