142 changes: 136 additions & 6 deletions ldai/testing/test_tracker.py
@@ -1,3 +1,4 @@
from time import sleep
from unittest.mock import MagicMock, call

import pytest
@@ -60,6 +61,43 @@ def test_tracks_duration(client: LDClient):
assert tracker.get_summary().duration == 100


def test_tracks_duration_of(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_duration_of(lambda: sleep(0.01))

calls = client.track.mock_calls # type: ignore

assert len(calls) == 1
assert calls[0].args[0] == '$ld:ai:duration:total'
assert calls[0].args[1] == context
assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
assert calls[0].args[3] == pytest.approx(10, rel=10)


def test_tracks_duration_of_with_exception(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

def sleep_and_throw():
sleep(0.01)
raise ValueError("Something went wrong")

try:
tracker.track_duration_of(sleep_and_throw)
assert False, "Should have thrown an exception"
except ValueError:
pass

calls = client.track.mock_calls # type: ignore

assert len(calls) == 1
assert calls[0].args[0] == '$ld:ai:duration:total'
assert calls[0].args[1] == context
assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
assert calls[0].args[3] == pytest.approx(10, rel=10)


def test_tracks_token_usage(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -97,6 +135,7 @@ def test_tracks_bedrock_metrics(client: LDClient):

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
@@ -110,6 +149,39 @@ def test_tracks_bedrock_metrics(client: LDClient):
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_bedrock_metrics_with_error(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

bedrock_result = {
'$metadata': {'httpStatusCode': 500},
'usage': {
'totalTokens': 330,
'inputTokens': 220,
'outputTokens': 110,
},
'metrics': {
'latencyMs': 50,
}
}
tracker.track_bedrock_converse_metrics(bedrock_result)

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is False
assert tracker.get_summary().duration == 50
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_openai_metrics(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -129,6 +201,8 @@ def to_dict(self):
tracker.track_openai_metrics(lambda: Result())

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
@@ -139,6 +213,29 @@ def to_dict(self):
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_openai_metrics_with_exception(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

def raise_exception():
raise ValueError("Something went wrong")

try:
tracker.track_openai_metrics(raise_exception)
assert False, "Should have thrown an exception"
except ValueError:
pass

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls, any_order=False) # type: ignore

assert tracker.get_summary().usage is None


@pytest.mark.parametrize(
"kind,label",
[
@@ -166,11 +263,44 @@ def test_tracks_success(client: LDClient):
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_success()

client.track.assert_called_with( # type: ignore
'$ld:ai:generation',
context,
{'variationKey': 'variation-key', 'configKey': 'config-key'},
1
)
calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is True


def test_tracks_error(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_error()

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is False


def test_error_overwrites_success(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_success()
tracker.track_error()

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is False
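
Note: the `client` fixture these tests receive is defined in a collapsed part of the file and is not visible in this diff. A minimal sketch of the shape the new assertions rely on, assuming a MagicMock-backed `track` method (a hypothetical reconstruction, not the actual fixture):

from unittest.mock import MagicMock

import pytest
from ldclient.client import LDClient


@pytest.fixture
def client() -> LDClient:
    # Hypothetical: the real fixture is collapsed out of this diff; the
    # assertions above only need client.track to behave like a MagicMock.
    return MagicMock(spec=LDClient)
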
58 changes: 49 additions & 9 deletions ldai/tracker.py
@@ -106,14 +106,20 @@ def track_duration_of(self, func):
"""
Automatically track the duration of an AI operation.

        If an exception occurs while the function is executing, the duration
        is still tracked and the exception is re-raised.

:param func: Function to track.
:return: Result of the tracked function.
"""
start_time = time.time()
result = func()
end_time = time.time()
duration = int((end_time - start_time) * 1000) # duration in milliseconds
self.track_duration(duration)
try:
result = func()
finally:
end_time = time.time()
duration = int((end_time - start_time) * 1000) # duration in milliseconds
self.track_duration(duration)

return result

def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
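
For context, a minimal caller-side sketch of the revised track_duration_of behaviour. The direct construction mirrors the tests above; ld_client is assumed to be an initialized LDClient and make_request is a placeholder for the real AI provider call:

from ldclient import Context

from ldai.tracker import LDAIConfigTracker

context = Context.create('user-key')
tracker = LDAIConfigTracker(ld_client, 'variation-key', 'config-key', context)

def make_request():
    ...  # placeholder for the actual AI provider call

try:
    # The duration of the call is measured and sent as
    # '$ld:ai:duration:total' even when make_request raises.
    result = tracker.track_duration_of(make_request)
except Exception:
    # The original exception is re-raised after the duration is recorded.
    raise
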
@@ -146,32 +152,66 @@ def track_success(self) -> None:
self._ld_client.track(
'$ld:ai:generation', self._context, self.__get_track_data(), 1
)
self._ld_client.track(
'$ld:ai:generation:success', self._context, self.__get_track_data(), 1
)

def track_error(self) -> None:
"""
Track an unsuccessful AI generation attempt.
"""
self._summary._success = False
self._ld_client.track(
'$ld:ai:generation', self._context, self.__get_track_data(), 1
)
self._ld_client.track(
'$ld:ai:generation:error', self._context, self.__get_track_data(), 1
)

def track_openai_metrics(self, func):
"""
Track OpenAI-specific operations.

This function will track the duration of the operation, the token
usage, and the success or error status.

        If the provided function raises an exception, this method records the
        duration and an error, then re-raises the exception. A failed
        operation will not have any token usage data.

:param func: Function to track.
:return: Result of the tracked function.
"""
result = self.track_duration_of(func)
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
try:
result = self.track_duration_of(func)
self.track_success()
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
except Exception:
self.track_error()
raise

return result
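
A short usage sketch of the new success/error handling in track_openai_metrics; openai_call is a placeholder for the actual OpenAI request:

def openai_call():
    ...  # placeholder, e.g. an OpenAI chat completion request

try:
    # On success: duration, token usage, '$ld:ai:generation' and
    # '$ld:ai:generation:success' are all recorded.
    completion = tracker.track_openai_metrics(openai_call)
except Exception:
    # On failure: duration and '$ld:ai:generation:error' are still
    # recorded, no token usage is tracked, and the exception propagates.
    raise
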

def track_bedrock_converse_metrics(self, res: dict) -> dict:
"""
Track AWS Bedrock conversation operations.

This function will track the duration of the operation, the token
usage, and the success or error status.

:param res: Response dictionary from Bedrock.
:return: The original response dictionary.
"""
status_code = res.get('$metadata', {}).get('httpStatusCode', 0)
if status_code == 200:
self.track_success()
elif status_code >= 400:
# Potentially add error tracking in the future.
pass
self.track_error()
if res.get('metrics', {}).get('latencyMs'):
self.track_duration(res['metrics']['latencyMs'])
if res.get('usage'):
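
Finally, a sketch of the Bedrock path, assuming bedrock_runtime is a boto3 'bedrock-runtime' client and the model ID and messages are placeholders. As the diff above shows, success or error is chosen from the httpStatusCode in the response metadata:

response = bedrock_runtime.converse(
    modelId='example-model-id',  # placeholder
    messages=[{'role': 'user', 'content': [{'text': 'Hello'}]}],
)

# Records a success event for HTTP 200 or an error event for status
# codes >= 400, plus latency and token counts when present, and returns
# the original response dictionary.
tracked = tracker.track_bedrock_converse_metrics(response)
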