Skip to content

Commit a9e91dc

Browse files
committed
fix!: Unify tracking token to use only TokenUsage
1 parent e425b1f commit a9e91dc

File tree

2 files changed

+127
-121
lines changed

2 files changed

+127
-121
lines changed

ldai/testing/test_tracker.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from unittest.mock import MagicMock
1+
from unittest.mock import MagicMock, call
22

33
import pytest
44
from ldclient import Config, Context, LDClient
55
from ldclient.integrations.test_data import TestData
66

7-
from ldai.tracker import FeedbackKind, LDAIConfigTracker
7+
from ldai.tracker import FeedbackKind, LDAIConfigTracker, TokenUsage
88

99

1010
@pytest.fixture
@@ -60,6 +60,85 @@ def test_tracks_duration(client: LDClient):
6060
assert tracker.get_summary().duration == 100
6161

6262

63+
def test_tracks_token_usage(client: LDClient):
    """track_tokens emits one event per token metric and records the usage on the summary."""
    ctx = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", ctx)

    usage = TokenUsage(300, 200, 100)
    tracker.track_tokens(usage)

    # Every token event carries the same variation/config metadata.
    track_data = {'variationKey': 'variation-key', 'configKey': 'config-key'}
    expected = [
        call('$ld:ai:tokens:total', ctx, track_data, 300),
        call('$ld:ai:tokens:input', ctx, track_data, 200),
        call('$ld:ai:tokens:output', ctx, track_data, 100),
    ]
    client.track.assert_has_calls(expected)  # type: ignore

    assert tracker.get_summary().usage == usage
79+
80+
81+
def test_tracks_bedrock_metrics(client: LDClient):
    """A Bedrock converse result yields generation, duration, and token events plus a populated summary."""
    ctx = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", ctx)

    # Minimal shape of a boto3 `converse` response: status metadata, usage, latency.
    converse_response = {
        '$metadata': {'httpStatusCode': 200},
        'usage': {
            'totalTokens': 330,
            'inputTokens': 220,
            'outputTokens': 110,
        },
        'metrics': {
            'latencyMs': 50,
        }
    }
    tracker.track_bedrock_converse_metrics(converse_response)

    track_data = {'variationKey': 'variation-key', 'configKey': 'config-key'}
    expected = [
        call('$ld:ai:generation', ctx, track_data, 1),
        call('$ld:ai:duration:total', ctx, track_data, 50),
        call('$ld:ai:tokens:total', ctx, track_data, 330),
        call('$ld:ai:tokens:input', ctx, track_data, 220),
        call('$ld:ai:tokens:output', ctx, track_data, 110),
    ]
    client.track.assert_has_calls(expected)  # type: ignore

    summary = tracker.get_summary()
    assert summary.success is True
    assert summary.duration == 50
    assert summary.usage == TokenUsage(330, 220, 110)
111+
112+
113+
def test_tracks_openai_metrics(client: LDClient):
    """An OpenAI-style result (object with `usage.to_dict()`) produces the three token events."""
    ctx = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", ctx)

    # Fakes mimicking the openai client's completion object: the tracker only
    # needs `result.usage.to_dict()` returning the usual token-count keys.
    class FakeUsage:
        def to_dict(self):
            return {
                'total_tokens': 330,
                'prompt_tokens': 220,
                'completion_tokens': 110,
            }

    class FakeCompletion:
        def __init__(self):
            self.usage = FakeUsage()

    # The class itself is a zero-argument callable returning a completion.
    tracker.track_openai_metrics(FakeCompletion)

    track_data = {'variationKey': 'variation-key', 'configKey': 'config-key'}
    expected = [
        call('$ld:ai:tokens:total', ctx, track_data, 330),
        call('$ld:ai:tokens:input', ctx, track_data, 220),
        call('$ld:ai:tokens:output', ctx, track_data, 110),
    ]
    client.track.assert_has_calls(expected, any_order=False)  # type: ignore

    assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
140+
141+
63142
@pytest.mark.parametrize(
64143
"kind,label",
65144
[

ldai/tracker.py

Lines changed: 46 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,11 @@
11
import time
22
from dataclasses import dataclass
33
from enum import Enum
4-
from typing import Dict, Optional, Union
4+
from typing import Dict, Optional
55

66
from ldclient import Context, LDClient
77

88

9-
@dataclass
10-
class TokenMetrics:
11-
"""
12-
Metrics for token usage in AI operations.
13-
14-
:param total: Total number of tokens used.
15-
:param input: Number of input tokens.
16-
:param output: Number of output tokens.
17-
"""
18-
19-
total: int
20-
input: int
21-
output: int # type: ignore
22-
23-
249
class FeedbackKind(Enum):
2510
"""
2611
Types of feedback that can be provided for AI operations.
@@ -35,99 +20,14 @@ class TokenUsage:
3520
"""
3621
Tracks token usage for AI operations.
3722
38-
:param total_tokens: Total number of tokens used.
39-
:param prompt_tokens: Number of tokens in the prompt.
40-
:param completion_tokens: Number of tokens in the completion.
41-
"""
42-
43-
total_tokens: int
44-
prompt_tokens: int
45-
completion_tokens: int
46-
47-
def to_metrics(self):
48-
"""
49-
Convert token usage to metrics format.
50-
51-
:return: Dictionary containing token metrics.
52-
"""
53-
return {
54-
'total': self['total_tokens'],
55-
'input': self['prompt_tokens'],
56-
'output': self['completion_tokens'],
57-
}
58-
59-
60-
@dataclass
61-
class LDOpenAIUsage:
62-
"""
63-
LaunchDarkly-specific OpenAI usage tracking.
64-
65-
:param total_tokens: Total number of tokens used.
66-
:param prompt_tokens: Number of tokens in the prompt.
67-
:param completion_tokens: Number of tokens in the completion.
68-
"""
69-
70-
total_tokens: int
71-
prompt_tokens: int
72-
completion_tokens: int
73-
74-
75-
@dataclass
76-
class OpenAITokenUsage:
77-
"""
78-
Tracks OpenAI-specific token usage.
79-
"""
80-
81-
def __init__(self, data: LDOpenAIUsage):
82-
"""
83-
Initialize OpenAI token usage tracking.
84-
85-
:param data: OpenAI usage data.
86-
"""
87-
self.total_tokens = data.total_tokens
88-
self.prompt_tokens = data.prompt_tokens
89-
self.completion_tokens = data.completion_tokens
90-
91-
def to_metrics(self) -> TokenMetrics:
92-
"""
93-
Convert OpenAI token usage to metrics format.
94-
95-
:return: TokenMetrics object containing usage data.
96-
"""
97-
return TokenMetrics(
98-
total=self.total_tokens,
99-
input=self.prompt_tokens,
100-
output=self.completion_tokens,
101-
)
102-
103-
104-
@dataclass
105-
class BedrockTokenUsage:
106-
"""
107-
Tracks AWS Bedrock-specific token usage.
23+
:param total: Total number of tokens used.
24+
:param input: Number of tokens in the prompt.
25+
:param output: Number of tokens in the completion.
10826
"""
10927

110-
def __init__(self, data: dict):
111-
"""
112-
Initialize Bedrock token usage tracking.
113-
114-
:param data: Dictionary containing Bedrock usage data.
115-
"""
116-
self.totalTokens = data.get('totalTokens', 0)
117-
self.inputTokens = data.get('inputTokens', 0)
118-
self.outputTokens = data.get('outputTokens', 0)
119-
120-
def to_metrics(self) -> TokenMetrics:
121-
"""
122-
Convert Bedrock token usage to metrics format.
123-
124-
:return: TokenMetrics object containing usage data.
125-
"""
126-
return TokenMetrics(
127-
total=self.totalTokens,
128-
input=self.inputTokens,
129-
output=self.outputTokens,
130-
)
28+
total: int
29+
input: int
30+
output: int
13131

13232

13333
class LDAIMetricSummary:
@@ -154,7 +54,7 @@ def feedback(self) -> Optional[Dict[str, FeedbackKind]]:
15454
return self._feedback
15555

15656
@property
157-
def usage(self) -> Optional[Union[TokenUsage, BedrockTokenUsage]]:
57+
def usage(self) -> Optional[TokenUsage]:
15858
return self._usage
15959

16060

@@ -255,8 +155,8 @@ def track_openai_metrics(self, func):
255155
:return: Result of the tracked function.
256156
"""
257157
result = self.track_duration_of(func)
258-
if result.usage:
259-
self.track_tokens(OpenAITokenUsage(result.usage))
158+
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
159+
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
260160
return result
261161

262162
def track_bedrock_converse_metrics(self, res: dict) -> dict:
@@ -275,37 +175,36 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict:
275175
if res.get('metrics', {}).get('latencyMs'):
276176
self.track_duration(res['metrics']['latencyMs'])
277177
if res.get('usage'):
278-
self.track_tokens(BedrockTokenUsage(res['usage']))
178+
self.track_tokens(_bedrock_to_token_usage(res['usage']))
279179
return res
280180

281-
def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None:
181+
def track_tokens(self, tokens: TokenUsage) -> None:
282182
"""
283183
Track token usage metrics.
284184
285185
:param tokens: Token usage data from either custom, OpenAI, or Bedrock sources.
286186
"""
287187
self._summary._usage = tokens
288-
token_metrics = tokens.to_metrics()
289-
if token_metrics.total > 0:
188+
if tokens.total > 0:
290189
self._ld_client.track(
291190
'$ld:ai:tokens:total',
292191
self._context,
293192
self.__get_track_data(),
294-
token_metrics.total,
193+
tokens.total,
295194
)
296-
if token_metrics.input > 0:
195+
if tokens.input > 0:
297196
self._ld_client.track(
298197
'$ld:ai:tokens:input',
299198
self._context,
300199
self.__get_track_data(),
301-
token_metrics.input,
200+
tokens.input,
302201
)
303-
if token_metrics.output > 0:
202+
if tokens.output > 0:
304203
self._ld_client.track(
305204
'$ld:ai:tokens:output',
306205
self._context,
307206
self.__get_track_data(),
308-
token_metrics.output,
207+
tokens.output,
309208
)
310209

311210
def get_summary(self) -> LDAIMetricSummary:
@@ -315,3 +214,31 @@ def get_summary(self) -> LDAIMetricSummary:
315214
:return: Summary of AI metrics.
316215
"""
317216
return self._summary
217+
218+
219+
def _bedrock_to_token_usage(data: dict) -> TokenUsage:
    """
    Convert a Bedrock usage dictionary to a TokenUsage object.

    Missing keys default to 0.

    :param data: Dictionary containing Bedrock usage data.
    :return: TokenUsage object containing usage data.
    """
    total = data.get('totalTokens', 0)
    prompt = data.get('inputTokens', 0)
    completion = data.get('outputTokens', 0)
    return TokenUsage(total=total, input=prompt, output=completion)
231+
232+
233+
def _openai_to_token_usage(data: dict) -> TokenUsage:
    """
    Convert an OpenAI usage dictionary to a TokenUsage object.

    Missing keys default to 0.

    :param data: Dictionary containing OpenAI usage data.
    :return: TokenUsage object containing usage data.
    """
    total = data.get('total_tokens', 0)
    prompt = data.get('prompt_tokens', 0)
    completion = data.get('completion_tokens', 0)
    return TokenUsage(total=total, input=prompt, output=completion)

0 commit comments

Comments
 (0)