Skip to content

Commit 0ef82ce

Browse files
committed
feat: Add track_error to mirror track_success
Additionally, emit new `$ld:ai:generation:(success|error)` events on success or failure.
1 parent 80e1845 commit 0ef82ce

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed

ldai/testing/test_tracker.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def test_tracks_bedrock_metrics(client: LDClient):
9797

9898
calls = [
9999
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
100+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
100101
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
101102
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
102103
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
@@ -110,6 +111,39 @@ def test_tracks_bedrock_metrics(client: LDClient):
110111
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
111112

112113

114+
def test_tracks_bedrock_metrics_with_error(client: LDClient):
115+
context = Context.create('user-key')
116+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
117+
118+
bedrock_result = {
119+
'$metadata': {'httpStatusCode': 500},
120+
'usage': {
121+
'totalTokens': 330,
122+
'inputTokens': 220,
123+
'outputTokens': 110,
124+
},
125+
'metrics': {
126+
'latencyMs': 50,
127+
}
128+
}
129+
tracker.track_bedrock_converse_metrics(bedrock_result)
130+
131+
calls = [
132+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
133+
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
134+
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
135+
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
136+
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
137+
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
138+
]
139+
140+
client.track.assert_has_calls(calls) # type: ignore
141+
142+
assert tracker.get_summary().success is False
143+
assert tracker.get_summary().duration == 50
144+
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
145+
146+
113147
def test_tracks_openai_metrics(client: LDClient):
114148
context = Context.create('user-key')
115149
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -166,11 +200,44 @@ def test_tracks_success(client: LDClient):
166200
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
167201
tracker.track_success()
168202

169-
client.track.assert_called_with( # type: ignore
170-
'$ld:ai:generation',
171-
context,
172-
{'variationKey': 'variation-key', 'configKey': 'config-key'},
173-
1
174-
)
203+
calls = [
204+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
205+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
206+
]
207+
208+
client.track.assert_has_calls(calls) # type: ignore
175209

176210
assert tracker.get_summary().success is True
211+
212+
213+
def test_tracks_error(client: LDClient):
214+
context = Context.create('user-key')
215+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
216+
tracker.track_error()
217+
218+
calls = [
219+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
220+
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
221+
]
222+
223+
client.track.assert_has_calls(calls) # type: ignore
224+
225+
assert tracker.get_summary().success is False
226+
227+
228+
def test_error_overwrites_success(client: LDClient):
229+
context = Context.create('user-key')
230+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
231+
tracker.track_success()
232+
tracker.track_error()
233+
234+
calls = [
235+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
236+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
237+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
238+
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
239+
]
240+
241+
client.track.assert_has_calls(calls) # type: ignore
242+
243+
assert tracker.get_summary().success is False

ldai/tracker.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,21 @@ def track_success(self) -> None:
146146
self._ld_client.track(
147147
'$ld:ai:generation', self._context, self.__get_track_data(), 1
148148
)
149+
self._ld_client.track(
150+
'$ld:ai:generation:success', self._context, self.__get_track_data(), 1
151+
)
152+
153+
def track_error(self) -> None:
154+
"""
155+
Track an unsuccessful AI generation attempt.
156+
"""
157+
self._summary._success = False
158+
self._ld_client.track(
159+
'$ld:ai:generation', self._context, self.__get_track_data(), 1
160+
)
161+
self._ld_client.track(
162+
'$ld:ai:generation:error', self._context, self.__get_track_data(), 1
163+
)
149164

150165
def track_openai_metrics(self, func):
151166
"""
@@ -170,8 +185,7 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict:
170185
if status_code == 200:
171186
self.track_success()
172187
elif status_code >= 400:
173-
# Potentially add error tracking in the future.
174-
pass
188+
self.track_error()
175189
if res.get('metrics', {}).get('latencyMs'):
176190
self.track_duration(res['metrics']['latencyMs'])
177191
if res.get('usage'):

0 commit comments

Comments
 (0)