83 changes: 81 additions & 2 deletions ldai/testing/test_tracker.py
@@ -1,10 +1,10 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, call

import pytest
from ldclient import Config, Context, LDClient
from ldclient.integrations.test_data import TestData

from ldai.tracker import FeedbackKind, LDAIConfigTracker
from ldai.tracker import FeedbackKind, LDAIConfigTracker, TokenUsage


@pytest.fixture
@@ -60,6 +60,85 @@ def test_tracks_duration(client: LDClient):
assert tracker.get_summary().duration == 100


def test_tracks_token_usage(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

tokens = TokenUsage(300, 200, 100)
tracker.track_tokens(tokens)

calls = [
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 300),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 200),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 100),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().usage == tokens


def test_tracks_bedrock_metrics(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

bedrock_result = {
'$metadata': {'httpStatusCode': 200},
'usage': {
'totalTokens': 330,
'inputTokens': 220,
'outputTokens': 110,
},
'metrics': {
'latencyMs': 50,
}
}
tracker.track_bedrock_converse_metrics(bedrock_result)

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is True
assert tracker.get_summary().duration == 50
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_openai_metrics(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

class Result:
def __init__(self):
self.usage = Usage()

class Usage:
def to_dict(self):
return {
'total_tokens': 330,
'prompt_tokens': 220,
'completion_tokens': 110,
}

tracker.track_openai_metrics(lambda: Result())

calls = [
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
]

client.track.assert_has_calls(calls, any_order=False) # type: ignore

assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


@pytest.mark.parametrize(
"kind,label",
[
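A minimal sketch of the call pattern these new tests exercise (not part of the patch; assumes an LDClient named `client` configured as in the fixture above, and the variation/config key names are illustrative):

from ldclient import Context
from ldai.tracker import LDAIConfigTracker, TokenUsage

context = Context.create('user-key')
tracker = LDAIConfigTracker(client, 'variation-key', 'config-key', context)

# Provider-agnostic tracking: TokenUsage fields are now (total, input, output).
tracker.track_tokens(TokenUsage(total=300, input=200, output=100))

# Bedrock Converse responses are passed through as plain dicts.
tracker.track_bedrock_converse_metrics({
    '$metadata': {'httpStatusCode': 200},
    'usage': {'totalTokens': 330, 'inputTokens': 220, 'outputTokens': 110},
    'metrics': {'latencyMs': 50},
})

# track_openai_metrics only requires the wrapped result's `usage` attribute
# to expose a to_dict() method, as the test's stand-in Result class shows.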
165 changes: 46 additions & 119 deletions ldai/tracker.py
@@ -1,26 +1,11 @@
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional, Union
from typing import Dict, Optional

from ldclient import Context, LDClient


@dataclass
class TokenMetrics:
"""
Metrics for token usage in AI operations.

:param total: Total number of tokens used.
:param input: Number of input tokens.
:param output: Number of output tokens.
"""

total: int
input: int
output: int # type: ignore


class FeedbackKind(Enum):
"""
Types of feedback that can be provided for AI operations.
@@ -35,99 +20,14 @@ class TokenUsage:
"""
Tracks token usage for AI operations.

:param total_tokens: Total number of tokens used.
:param prompt_tokens: Number of tokens in the prompt.
:param completion_tokens: Number of tokens in the completion.
"""

total_tokens: int
prompt_tokens: int
completion_tokens: int

def to_metrics(self):
"""
Convert token usage to metrics format.

:return: Dictionary containing token metrics.
"""
return {
'total': self['total_tokens'],
'input': self['prompt_tokens'],
'output': self['completion_tokens'],
}


@dataclass
class LDOpenAIUsage:
"""
LaunchDarkly-specific OpenAI usage tracking.

:param total_tokens: Total number of tokens used.
:param prompt_tokens: Number of tokens in the prompt.
:param completion_tokens: Number of tokens in the completion.
"""

total_tokens: int
prompt_tokens: int
completion_tokens: int


@dataclass
class OpenAITokenUsage:
"""
Tracks OpenAI-specific token usage.
"""

def __init__(self, data: LDOpenAIUsage):
"""
Initialize OpenAI token usage tracking.

:param data: OpenAI usage data.
"""
self.total_tokens = data.total_tokens
self.prompt_tokens = data.prompt_tokens
self.completion_tokens = data.completion_tokens

def to_metrics(self) -> TokenMetrics:
"""
Convert OpenAI token usage to metrics format.

:return: TokenMetrics object containing usage data.
"""
return TokenMetrics(
total=self.total_tokens,
input=self.prompt_tokens,
output=self.completion_tokens,
)


@dataclass
class BedrockTokenUsage:
"""
Tracks AWS Bedrock-specific token usage.
:param total: Total number of tokens used.
:param input: Number of tokens in the prompt.
:param output: Number of tokens in the completion.
"""

def __init__(self, data: dict):
"""
Initialize Bedrock token usage tracking.

:param data: Dictionary containing Bedrock usage data.
"""
self.totalTokens = data.get('totalTokens', 0)
self.inputTokens = data.get('inputTokens', 0)
self.outputTokens = data.get('outputTokens', 0)

def to_metrics(self) -> TokenMetrics:
"""
Convert Bedrock token usage to metrics format.

:return: TokenMetrics object containing usage data.
"""
return TokenMetrics(
total=self.totalTokens,
input=self.inputTokens,
output=self.outputTokens,
)
total: int
input: int
output: int


class LDAIMetricSummary:
Expand All @@ -154,7 +54,7 @@ def feedback(self) -> Optional[Dict[str, FeedbackKind]]:
return self._feedback

@property
def usage(self) -> Optional[Union[TokenUsage, BedrockTokenUsage]]:
def usage(self) -> Optional[TokenUsage]:
return self._usage


@@ -255,8 +155,8 @@ def track_openai_metrics(self, func):
:return: Result of the tracked function.
"""
result = self.track_duration_of(func)
if result.usage:
self.track_tokens(OpenAITokenUsage(result.usage))
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
return result

def track_bedrock_converse_metrics(self, res: dict) -> dict:
Expand All @@ -275,37 +175,36 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict:
if res.get('metrics', {}).get('latencyMs'):
self.track_duration(res['metrics']['latencyMs'])
if res.get('usage'):
self.track_tokens(BedrockTokenUsage(res['usage']))
self.track_tokens(_bedrock_to_token_usage(res['usage']))
return res

def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None:
def track_tokens(self, tokens: TokenUsage) -> None:
"""
Track token usage metrics.

:param tokens: Token usage data from either custom, OpenAI, or Bedrock sources.
"""
self._summary._usage = tokens
token_metrics = tokens.to_metrics()
if token_metrics.total > 0:
if tokens.total > 0:
self._ld_client.track(
'$ld:ai:tokens:total',
self._context,
self.__get_track_data(),
token_metrics.total,
tokens.total,
)
if token_metrics.input > 0:
if tokens.input > 0:
self._ld_client.track(
'$ld:ai:tokens:input',
self._context,
self.__get_track_data(),
token_metrics.input,
tokens.input,
)
if token_metrics.output > 0:
if tokens.output > 0:
self._ld_client.track(
'$ld:ai:tokens:output',
self._context,
self.__get_track_data(),
token_metrics.output,
tokens.output,
)

def get_summary(self) -> LDAIMetricSummary:
@@ -315,3 +214,31 @@ def get_summary(self) -> LDAIMetricSummary:
:return: Summary of AI metrics.
"""
return self._summary


def _bedrock_to_token_usage(data: dict) -> TokenUsage:
"""
Convert a Bedrock usage dictionary to a TokenUsage object.

:param data: Dictionary containing Bedrock usage data.
:return: TokenUsage object containing usage data.
"""
return TokenUsage(
total=data.get('totalTokens', 0),
input=data.get('inputTokens', 0),
output=data.get('outputTokens', 0),
)


def _openai_to_token_usage(data: dict) -> TokenUsage:
"""
Convert an OpenAI usage dictionary to a TokenUsage object.

:param data: Dictionary containing OpenAI usage data.
:return: TokenUsage object containing usage data.
"""
return TokenUsage(
total=data.get('total_tokens', 0),
input=data.get('prompt_tokens', 0),
output=data.get('completion_tokens', 0),
)
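
To summarize the refactor, an illustrative snippet (not part of the patch; the usage values are hypothetical and the helpers are module-private) showing how both provider formats now normalize to the single TokenUsage dataclass:

from ldai.tracker import TokenUsage, _bedrock_to_token_usage, _openai_to_token_usage

bedrock_usage = {'totalTokens': 330, 'inputTokens': 220, 'outputTokens': 110}
openai_usage = {'total_tokens': 330, 'prompt_tokens': 220, 'completion_tokens': 110}

# Both helpers yield the same provider-agnostic value.
assert _bedrock_to_token_usage(bedrock_usage) == _openai_to_token_usage(openai_usage)
assert _bedrock_to_token_usage(bedrock_usage) == TokenUsage(total=330, input=220, output=110)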