
Commit 445ab8c

Add Chat implementation
1 parent 231ae2e commit 445ab8c

File tree: 3 files changed (+304, -0 lines)


ldai/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -27,6 +27,9 @@
 # Export judge
 from ldai.judge import AIJudge
 
+# Export chat
+from ldai.chat import TrackedChat
+
 # Export judge types
 from ldai.providers.types import EvalScore, JudgeResponse
 
@@ -41,6 +44,7 @@
     'AIJudgeConfig',
     'AIJudgeConfigDefault',
     'AIJudge',
+    'TrackedChat',
     'EvalScore',
     'JudgeConfiguration',
     'JudgeResponse',
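
With this re-export, consumers can import the new class from the package root instead of reaching into the ldai.chat submodule. A minimal sketch, using only the names added in this hunk:

    # TrackedChat is now part of the package's public API (__all__).
    from ldai import TrackedChat  # same class as ldai.chat.TrackedChat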

ldai/chat/__init__.py

Lines changed: 191 additions & 0 deletions (new file)

@@ -0,0 +1,191 @@
"""TrackedChat implementation for managing AI chat conversations."""

from typing import Any, Dict, List, Optional

from ldai.models import AICompletionConfig, LDMessage
from ldai.providers.ai_provider import AIProvider
from ldai.providers.types import ChatResponse, JudgeResponse
from ldai.judge import AIJudge
from ldai.tracker import LDAIConfigTracker


class TrackedChat:
    """
    Concrete implementation of TrackedChat that provides chat functionality
    by delegating to an AIProvider implementation.

    This class handles conversation management and tracking, while delegating
    the actual model invocation to the provider.
    """

    def __init__(
        self,
        ai_config: AICompletionConfig,
        tracker: LDAIConfigTracker,
        provider: AIProvider,
        judges: Optional[Dict[str, AIJudge]] = None,
        logger: Optional[Any] = None,
    ):
        """
        Initialize the TrackedChat.

        :param ai_config: The completion AI configuration
        :param tracker: The tracker for the completion configuration
        :param provider: The AI provider to use for chat
        :param judges: Optional dictionary of judge instances keyed by their configuration keys
        :param logger: Optional logger for logging
        """
        self._ai_config = ai_config
        self._tracker = tracker
        self._provider = provider
        self._judges = judges or {}
        self._logger = logger
        self._messages: List[LDMessage] = []

    async def invoke(self, prompt: str) -> ChatResponse:
        """
        Invoke the chat model with a prompt string.

        This method handles conversation management and tracking, delegating to the provider's invoke_model method.

        :param prompt: The user prompt to send to the chat model
        :return: ChatResponse containing the model's response and metrics
        """
        # Convert prompt string to LDMessage with role 'user' and add to conversation history
        user_message: LDMessage = LDMessage(role='user', content=prompt)
        self._messages.append(user_message)

        # Prepend config messages to conversation history for model invocation
        config_messages = self._ai_config.messages or []
        all_messages = config_messages + self._messages

        # Delegate to provider-specific implementation with tracking
        response = await self._tracker.track_metrics_of(
            lambda result: result.metrics,
            lambda: self._provider.invoke_model(all_messages),
        )

        # Evaluate with judges if configured
        if (
            self._ai_config.judge_configuration
            and self._ai_config.judge_configuration.judges
            and len(self._ai_config.judge_configuration.judges) > 0
        ):
            evaluations = await self._evaluate_with_judges(self._messages, response)
            response.evaluations = evaluations

        # Add the response message to conversation history
        self._messages.append(response.message)
        return response

    async def _evaluate_with_judges(
        self,
        messages: List[LDMessage],
        response: ChatResponse,
    ) -> List[Optional[JudgeResponse]]:
        """
        Evaluates the response with all configured judges.

        Returns a list of evaluation results.

        :param messages: List of messages representing the conversation history
        :param response: The AI response to be evaluated
        :return: List of judge evaluation results (may contain None for failed evaluations)
        """
        if not self._ai_config.judge_configuration or not self._ai_config.judge_configuration.judges:
            return []

        judge_configs = self._ai_config.judge_configuration.judges

        # Start all judge evaluations in parallel
        async def evaluate_judge(judge_config):
            judge = self._judges.get(judge_config.key)
            if not judge:
                if self._logger:
                    self._logger.warn(
                        f"Judge configuration is not enabled: {judge_config.key}",
                    )
                return None

            eval_result = await judge.evaluate_messages(
                messages, response, judge_config.sampling_rate
            )

            if eval_result and eval_result.success:
                self._tracker.track_eval_scores(eval_result.evals)

            return eval_result

        # Ensure all evaluations complete even if some fail
        import asyncio
        evaluation_promises = [evaluate_judge(judge_config) for judge_config in judge_configs]
        results = await asyncio.gather(*evaluation_promises, return_exceptions=True)

        # Map exceptions to None
        return [
            None if isinstance(result, Exception) else result
            for result in results
        ]

    def get_config(self) -> AICompletionConfig:
        """
        Get the underlying AI configuration used to initialize this TrackedChat.

        :return: The AI completion configuration
        """
        return self._ai_config

    def get_tracker(self) -> LDAIConfigTracker:
        """
        Get the underlying AI configuration tracker used to initialize this TrackedChat.

        :return: The tracker instance
        """
        return self._tracker

    def get_provider(self) -> AIProvider:
        """
        Get the underlying AI provider instance.

        This provides direct access to the provider for advanced use cases.

        :return: The AI provider instance
        """
        return self._provider

    def get_judges(self) -> Dict[str, AIJudge]:
        """
        Get the judges associated with this TrackedChat.

        Returns a dictionary of judge instances keyed by their configuration keys.

        :return: Dictionary of judge instances
        """
        return self._judges

    def append_messages(self, messages: List[LDMessage]) -> None:
        """
        Append messages to the conversation history.

        Adds messages to the conversation history without invoking the model,
        which is useful for managing multi-turn conversations or injecting context.

        :param messages: List of messages to append to the conversation history
        """
        self._messages.extend(messages)

    def get_messages(self, include_config_messages: bool = False) -> List[LDMessage]:
        """
        Get all messages in the conversation history.

        :param include_config_messages: Whether to include the config messages from the AIConfig.
            Defaults to False.
        :return: List of messages. When include_config_messages is True, returns the conversation
            history with the config messages prepended. When False, returns only the
            conversation history messages.
        """
        if include_config_messages:
            config_messages = self._ai_config.messages or []
            return config_messages + self._messages
        return list(self._messages)
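
A rough usage sketch of the conversation-management surface above. It assumes a TrackedChat instance already obtained from the client's create_chat (added in ldai/client.py below); the run_conversation helper and the message contents are invented for illustration:

    from ldai.models import LDMessage


    async def run_conversation(chat) -> None:
        # `chat` is assumed to be a TrackedChat returned by create_chat();
        # construction details are omitted here.

        # Inject prior context without invoking the model (illustrative content).
        chat.append_messages([
            LDMessage(role='assistant', content='Hello! How can I help you today?'),
        ])

        # Each invoke() appends the user prompt and then the model reply
        # to the conversation history.
        first = await chat.invoke('What is the status of my order?')
        print(first.message.content)

        follow_up = await chat.invoke('And when will it arrive?')
        print(follow_up.message.content)

        # History only (config messages excluded by default) ...
        print(len(chat.get_messages()))
        # ... versus config messages prepended to the history.
        print(len(chat.get_messages(include_config_messages=True)))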

ldai/client.py

Lines changed: 109 additions & 0 deletions

@@ -4,6 +4,7 @@
 from ldclient import Context
 from ldclient.client import LDClient
 
+from ldai.chat import TrackedChat
 from ldai.judge import AIJudge
 from ldai.models import (
     AIAgentConfig,
@@ -192,6 +193,114 @@ async def create_judge(
             # Would log error if logger available
             return None
 
+    async def _initialize_judges(
+        self,
+        judge_configs: List[JudgeConfiguration.Judge],
+        context: Context,
+        variables: Optional[Dict[str, Any]] = None,
+        default_ai_provider: Optional[SupportedAIProvider] = None,
+    ) -> Dict[str, AIJudge]:
+        """
+        Initialize judges from judge configurations.
+
+        :param judge_configs: List of judge configurations
+        :param context: Standard Context used when evaluating flags
+        :param variables: Dictionary of values for instruction interpolation
+        :param default_ai_provider: Optional default AI provider to use
+        :return: Dictionary of judge instances keyed by their configuration keys
+        """
+        judges: Dict[str, AIJudge] = {}
+
+        async def create_judge_for_config(judge_key: str):
+            judge = await self.create_judge(
+                judge_key,
+                context,
+                AIJudgeConfigDefault(enabled=False),
+                variables,
+                default_ai_provider,
+            )
+            return judge_key, judge
+
+        judge_promises = [
+            create_judge_for_config(judge_config.key)
+            for judge_config in judge_configs
+        ]
+
+        import asyncio
+        results = await asyncio.gather(*judge_promises, return_exceptions=True)
+
+        for result in results:
+            if isinstance(result, Exception):
+                continue
+            judge_key, judge = result
+            if judge:
+                judges[judge_key] = judge
+
+        return judges
+
+    async def create_chat(
+        self,
+        key: str,
+        context: Context,
+        default_value: AICompletionConfigDefault,
+        variables: Optional[Dict[str, Any]] = None,
+        default_ai_provider: Optional[SupportedAIProvider] = None,
+    ) -> Optional[TrackedChat]:
+        """
+        Creates and returns a new TrackedChat instance for AI chat conversations.
+
+        :param key: The key identifying the AI completion configuration to use
+        :param context: Standard Context used when evaluating flags
+        :param default_value: A default value representing a standard AI config result
+        :param variables: Dictionary of values for instruction interpolation
+        :param default_ai_provider: Optional default AI provider to use
+        :return: TrackedChat instance or None if disabled/unsupported
+
+        Example::
+
+            chat = await client.create_chat(
+                "customer-support-chat",
+                context,
+                AICompletionConfigDefault(
+                    enabled=True,
+                    model=ModelConfig("gpt-4"),
+                    provider=ProviderConfig("openai"),
+                    messages=[LDMessage(role='system', content='You are a helpful assistant.')]
+                ),
+                variables={'customerName': 'John'}
+            )
+
+            if chat:
+                response = await chat.invoke("I need help with my order")
+                print(response.message.content)
+
+                # Access conversation history
+                messages = chat.get_messages()
+                print(f"Conversation has {len(messages)} messages")
+        """
+        self._client.track('$ld:ai:config:function:createChat', context, key, 1)
+
+        config = self.completion_config(key, context, default_value, variables)
+
+        if not config.enabled or not config.tracker:
+            # Would log info if logger available
+            return None
+
+        provider = await AIProviderFactory.create(config, None, default_ai_provider)
+        if not provider:
+            return None
+
+        judges = {}
+        if config.judge_configuration and config.judge_configuration.judges:
+            judges = await self._initialize_judges(
+                config.judge_configuration.judges,
+                context,
+                variables,
+                default_ai_provider,
+            )
+
+        return TrackedChat(config, config.tracker, provider, judges, None)
+
     def agent_config(
         self,
         key: str,
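
When the AI Config declares judges, create_chat initializes them and invoke() attaches their results to response.evaluations. A sketch of reading those results, assuming a chat created as in the docstring example above and assuming ChatResponse leaves evaluations unset when no judges ran; only the success and evals attributes used in this diff are touched:

    async def inspect_evaluations(chat) -> None:
        response = await chat.invoke('I need help with my order')

        # Entries may be None: a judge key that is not enabled, or an
        # evaluation that raised, is mapped to None in _evaluate_with_judges.
        for evaluation in getattr(response, 'evaluations', None) or []:
            if evaluation is None:
                continue
            if evaluation.success:
                # evals holds the scores that were also reported through
                # tracker.track_eval_scores() inside TrackedChat.
                print(evaluation.evals)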
