Skip to content

Commit 859f395

Browse files
committed
Handle exceptions for track duration of and openai metrics
1 parent 0ef82ce commit 859f395

File tree

2 files changed

+82
-7
lines changed

2 files changed

+82
-7
lines changed

ldai/testing/test_tracker.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import MagicMock, call
22

33
import pytest
4+
from time import sleep
45
from ldclient import Config, Context, LDClient
56
from ldclient.integrations.test_data import TestData
67

@@ -60,6 +61,43 @@ def test_tracks_duration(client: LDClient):
6061
assert tracker.get_summary().duration == 100
6162

6263

64+
def test_tracks_duration_of(client: LDClient):
65+
context = Context.create('user-key')
66+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
67+
tracker.track_duration_of(lambda: sleep(0.01))
68+
69+
calls = client.track.mock_calls # type: ignore
70+
71+
assert len(calls) == 1
72+
assert calls[0].args[0] == '$ld:ai:duration:total'
73+
assert calls[0].args[1] == context
74+
assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
75+
assert calls[0].args[3] == pytest.approx(10)
76+
77+
78+
def test_tracks_duration_of_with_exception(client: LDClient):
79+
context = Context.create('user-key')
80+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
81+
82+
def sleep_and_throw():
83+
sleep(0.01)
84+
raise ValueError("Something went wrong")
85+
86+
try:
87+
tracker.track_duration_of(sleep_and_throw)
88+
assert False, "Should have thrown an exception"
89+
except ValueError:
90+
pass
91+
92+
calls = client.track.mock_calls # type: ignore
93+
94+
assert len(calls) == 1
95+
assert calls[0].args[0] == '$ld:ai:duration:total'
96+
assert calls[0].args[1] == context
97+
assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
98+
assert calls[0].args[3] == pytest.approx(10)
99+
100+
63101
def test_tracks_token_usage(client: LDClient):
64102
context = Context.create('user-key')
65103
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -163,6 +201,8 @@ def to_dict(self):
163201
tracker.track_openai_metrics(lambda: Result())
164202

165203
calls = [
204+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
205+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
166206
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
167207
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
168208
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
@@ -173,6 +213,29 @@ def to_dict(self):
173213
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
174214

175215

216+
def test_tracks_openai_metrics_with_exception(client: LDClient):
217+
context = Context.create('user-key')
218+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
219+
220+
def raise_exception():
221+
raise ValueError("Something went wrong")
222+
223+
try:
224+
tracker.track_openai_metrics(raise_exception)
225+
assert False, "Should have thrown an exception"
226+
except ValueError:
227+
pass
228+
229+
calls = [
230+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
231+
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
232+
]
233+
234+
client.track.assert_has_calls(calls, any_order=False) # type: ignore
235+
236+
assert tracker.get_summary().usage is None
237+
238+
176239
@pytest.mark.parametrize(
177240
"kind,label",
178241
[

ldai/tracker.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,20 @@ def track_duration_of(self, func):
106106
"""
107107
Automatically track the duration of an AI operation.
108108
109+
An exception occurring during the execution of the function will still
110+
track the duration. The exception will be re-thrown.
111+
109112
:param func: Function to track.
110113
:return: Result of the tracked function.
111114
"""
112115
start_time = time.time()
113-
result = func()
114-
end_time = time.time()
115-
duration = int((end_time - start_time) * 1000) # duration in milliseconds
116-
self.track_duration(duration)
116+
try:
117+
result = func()
118+
finally:
119+
end_time = time.time()
120+
duration = int((end_time - start_time) * 1000) # duration in milliseconds
121+
self.track_duration(duration)
122+
117123
return result
118124

119125
def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
@@ -169,9 +175,15 @@ def track_openai_metrics(self, func):
169175
:param func: Function to track.
170176
:return: Result of the tracked function.
171177
"""
172-
result = self.track_duration_of(func)
173-
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
174-
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
178+
try:
179+
result = self.track_duration_of(func)
180+
self.track_success()
181+
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
182+
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
183+
except Exception:
184+
self.track_error()
185+
raise
186+
175187
return result
176188

177189
def track_bedrock_converse_metrics(self, res: dict) -> dict:

0 commit comments

Comments
 (0)