
Commit 3886510

Support llm token counting using llama index response (#63)
Co-authored-by: Deepak K <[email protected]>
1 parent 76b4ed9 commit 3886510

4 files changed: +142 -21 lines changed

src/unstract/sdk/audit.py

Lines changed: 5 additions & 3 deletions
@@ -1,11 +1,12 @@
-from typing import Any
+from typing import Any, Union
 
 import requests
 from llama_index.core.callbacks import CBEventType, TokenCountingHandler
 
 from unstract.sdk.constants import LogLevel, ToolEnv
 from unstract.sdk.helper import SdkHelper
 from unstract.sdk.tool.stream import StreamMixin
+from unstract.sdk.utils.token_counter import TokenCounter
 
 
 class Audit(StreamMixin):
@@ -25,7 +26,7 @@ def __init__(self, log_level: LogLevel = LogLevel.INFO) -> None:
     def push_usage_data(
         self,
         platform_api_key: str,
-        token_counter: TokenCountingHandler = None,
+        token_counter: Union[TokenCountingHandler, TokenCounter] = None,
         model_name: str = "",
         event_type: CBEventType = None,
         kwargs: dict[Any, Any] = None,
@@ -105,4 +106,5 @@ def push_usage_data(
             )
 
         finally:
-            token_counter.reset_counts()
+            if isinstance(token_counter, TokenCountingHandler):
+                token_counter.reset_counts()
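
With this change, push_usage_data accepts either llama-index's TokenCountingHandler or the SDK's new per-event TokenCounter, and reset_counts() is only invoked on the former, since only the handler keeps cumulative state. A minimal call sketch, assuming a valid platform API key (the key and model name below are placeholders):

    from unstract.sdk.audit import Audit
    from unstract.sdk.utils.token_counter import TokenCounter

    # Counts parsed from a single LLM response take the new path;
    # the isinstance guard in `finally` leaves this object untouched.
    counter = TokenCounter(input_tokens=12, output_tokens=34)
    Audit().push_usage_data(
        platform_api_key="PLATFORM_API_KEY",  # placeholder
        token_counter=counter,
        model_name="gpt-4",  # placeholder
    )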

src/unstract/sdk/utils/callback_manager.py

Lines changed: 24 additions & 16 deletions
@@ -6,7 +6,6 @@
 from llama_index.core.callbacks import TokenCountingHandler
 from llama_index.core.embeddings import BaseEmbedding
 from llama_index.core.llms import LLM
-from transformers import AutoTokenizer
 from typing_extensions import deprecated
 
 from unstract.sdk.utils.usage_handler import UsageHandler
@@ -77,24 +76,35 @@ def get_callback_manager(
         platform_api_key: str,
         kwargs,
     ) -> LlamaIndexCallbackManager:
-        tokenizer = CallbackManager.get_tokenizer(model)
-        token_counter = TokenCountingHandler(tokenizer=tokenizer, verbose=True)
         llm = None
         embedding = None
+        handler_list = []
         if isinstance(model, LLM):
             llm = model
+            usage_handler = UsageHandler(
+                platform_api_key=platform_api_key,
+                llm_model=llm,
+                embed_model=embedding,
+                kwargs=kwargs,
+            )
+            handler_list.append(usage_handler)
         elif isinstance(model, BaseEmbedding):
             embedding = model
-        usage_handler = UsageHandler(
-            token_counter=token_counter,
-            platform_api_key=platform_api_key,
-            llm_model=llm,
-            embed_model=embedding,
-            kwargs=kwargs,
-        )
+            # Get a tokenizer
+            tokenizer = CallbackManager.get_tokenizer(model)
+            token_counter = TokenCountingHandler(tokenizer=tokenizer, verbose=True)
+            usage_handler = UsageHandler(
+                token_counter=token_counter,
+                platform_api_key=platform_api_key,
+                llm_model=llm,
+                embed_model=embedding,
+                kwargs=kwargs,
+            )
+            handler_list.append(token_counter)
+            handler_list.append(usage_handler)
 
         callback_manager: LlamaIndexCallbackManager = LlamaIndexCallbackManager(
-            handlers=[token_counter, usage_handler]
+            handlers=handler_list
         )
         return callback_manager

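The net effect: for an LLM, only the UsageHandler is registered, because token counts are now read from the provider's response payload; for an embedding model, a tokenizer-backed TokenCountingHandler is still attached ahead of the UsageHandler, since embedding responses carry no usage block. A hypothetical call shape (the positional model argument and empty kwargs are assumptions):

    # Assuming `llm` is a configured llama-index LLM:
    manager = CallbackManager.get_callback_manager(
        llm, platform_api_key="PLATFORM_API_KEY", kwargs={}
    )
    # Expect [usage_handler] for LLMs, and
    # [token_counter, usage_handler] for BaseEmbedding models.
    print(len(manager.handlers))
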
@@ -124,11 +134,11 @@ def get_tokenizer(
             elif isinstance(model, BaseEmbedding):
                 model_name = model.model_name
 
-            tokenizer: Callable[[str], list] = AutoTokenizer.from_pretrained(
+            tokenizer: Callable[[str], list] = tiktoken.encoding_for_model(
                 model_name
             ).encode
             return tokenizer
-        except OSError as e:
+        except ValueError as e:
             logger.warning(str(e))
             return fallback_tokenizer

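For reference, a standalone sketch of the swapped lookup, assuming tiktoken is installed (note that tiktoken.encoding_for_model raises KeyError for model names it cannot map, so catching ValueError alone may not cover every miss):

    import tiktoken

    def tokenizer_or_fallback(model_name: str):
        try:
            return tiktoken.encoding_for_model(model_name).encode
        except (KeyError, ValueError):
            # Fall back to a general-purpose encoding on unmapped names.
            return tiktoken.get_encoding("cl100k_base").encode
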
@@ -145,8 +155,6 @@ def set_callback_manager(
         CallbackManager.set_callback(platform_api_key, model=llm, **kwargs)
         callback_manager = llm.callback_manager
         if embedding:
-            CallbackManager.set_callback_manager(
-                platform_api_key, model=embedding, **kwargs
-            )
+            CallbackManager.set_callback(platform_api_key, model=embedding, **kwargs)
             callback_manager = embedding.callback_manager
         return callback_manager
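
The removed lines also fixed a latent recursion: set_callback_manager previously re-entered itself for the embedding branch instead of delegating to set_callback. A hypothetical call, assuming `llm` and `embedding` are configured llama-index models and that both are accepted as keyword arguments:

    callback_manager = CallbackManager.set_callback_manager(
        platform_api_key="PLATFORM_API_KEY",  # placeholder
        llm=llm,
        embedding=embedding,
    )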

src/unstract/sdk/utils/token_counter.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+from typing import Any
+
+from llama_index.core.callbacks.schema import EventPayload
+from llama_index.core.utilities.token_counting import TokenCounter
+
+
+class Constants:
+    KEY_USAGE = "usage"
+    KEY_USAGE_METADATA = "usage_metadata"
+    KEY_EVAL_COUNT = "eval_count"
+    KEY_PROMPT_EVAL_COUNT = "prompt_eval_count"
+    KEY_RAW_RESPONSE = "_raw_response"
+    INPUT_TOKENS = "input_tokens"
+    OUTPUT_TOKENS = "output_tokens"
+    PROMPT_TOKENS = "prompt_tokens"
+    COMPLETION_TOKENS = "completion_tokens"
+    DEFAULT_TOKEN_COUNT = 0
+
+
+class TokenCounter:
+    prompt_llm_token_count: int
+    completion_llm_token_count: int
+    total_llm_token_count: int = 0
+    total_embedding_token_count: int = 0
+
+    def __init__(self, input_tokens, output_tokens):
+        self.prompt_llm_token_count = input_tokens
+        self.completion_llm_token_count = output_tokens
+        self.total_llm_token_count = (
+            self.prompt_llm_token_count + self.completion_llm_token_count
+        )
+
+    @staticmethod
+    def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
+        token_counter = TokenCounter(
+            input_tokens=Constants.DEFAULT_TOKEN_COUNT,
+            output_tokens=Constants.DEFAULT_TOKEN_COUNT,
+        )
+        if EventPayload.PROMPT in payload:
+            completion_raw = payload.get(EventPayload.COMPLETION).raw
+            if completion_raw:
+                if completion_raw.get(Constants.KEY_USAGE):
+                    token_counts: dict[
+                        str, int
+                    ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                    token_counter = TokenCounter(
+                        input_tokens=token_counts[Constants.PROMPT_TOKENS],
+                        output_tokens=token_counts[Constants.COMPLETION_TOKENS],
+                    )
+                elif completion_raw.get(Constants.KEY_RAW_RESPONSE):
+                    if hasattr(
+                        completion_raw.get(Constants.KEY_RAW_RESPONSE),
+                        Constants.KEY_USAGE_METADATA,
+                    ):
+                        usage = completion_raw.get(
+                            Constants.KEY_RAW_RESPONSE
+                        ).usage_metadata
+                        token_counter = TokenCounter(
+                            input_tokens=usage.prompt_token_count,
+                            output_tokens=usage.candidates_token_count,
+                        )
+                else:
+                    prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
+                    completion_tokens = Constants.DEFAULT_TOKEN_COUNT
+                    if completion_raw.get(Constants.KEY_PROMPT_EVAL_COUNT):
+                        prompt_tokens = completion_raw.get(
+                            Constants.KEY_PROMPT_EVAL_COUNT
+                        )
+                    if completion_raw.get(Constants.KEY_EVAL_COUNT):
+                        completion_tokens = completion_raw.get(Constants.KEY_EVAL_COUNT)
+                    token_counter = TokenCounter(
+                        input_tokens=prompt_tokens,
+                        output_tokens=completion_tokens,
+                    )
+        elif EventPayload.MESSAGES in payload:
+            response_raw = payload.get(EventPayload.RESPONSE).raw
+            if response_raw:
+                token_counts: dict[
+                    str, int
+                ] = TokenCounter._get_prompt_completion_tokens(response_raw)
+                token_counter = TokenCounter(
+                    input_tokens=token_counts[Constants.PROMPT_TOKENS],
+                    output_tokens=token_counts[Constants.COMPLETION_TOKENS],
+                )
+
+        return token_counter
+
+    @staticmethod
+    def _get_prompt_completion_tokens(response) -> dict[str, int]:
+        prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
+        completion_tokens = Constants.DEFAULT_TOKEN_COUNT
+
+        usage = response.get(Constants.KEY_USAGE)
+        if usage:
+            if hasattr(usage, Constants.INPUT_TOKENS):
+                prompt_tokens = usage.input_tokens
+            elif hasattr(usage, Constants.PROMPT_TOKENS):
+                prompt_tokens = usage.prompt_tokens
+
+            if hasattr(usage, Constants.OUTPUT_TOKENS):
+                completion_tokens = usage.output_tokens
+            elif hasattr(usage, Constants.COMPLETION_TOKENS):
+                completion_tokens = usage.completion_tokens
+
+        token_counts: dict[str, int] = dict()
+        token_counts[Constants.PROMPT_TOKENS] = prompt_tokens
+        token_counts[Constants.COMPLETION_TOKENS] = completion_tokens
+        return token_counts
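
A hedged walkthrough of the extraction order above, using a stand-in completion whose raw dict carries an Anthropic-style usage object (all values illustrative):

    from types import SimpleNamespace

    from llama_index.core.callbacks.schema import EventPayload

    from unstract.sdk.utils.token_counter import TokenCounter

    fake_completion = SimpleNamespace(
        raw={"usage": SimpleNamespace(input_tokens=12, output_tokens=34)}
    )
    payload = {
        EventPayload.PROMPT: "Hello",
        EventPayload.COMPLETION: fake_completion,
    }

    counts = TokenCounter.get_llm_token_counts(payload)
    assert counts.prompt_llm_token_count == 12
    assert counts.completion_llm_token_count == 34
    assert counts.total_llm_token_count == 46

If the raw dict carried no "usage" key, the fallback branches would try a Gemini-style _raw_response.usage_metadata object and then Ollama-style prompt_eval_count / eval_count fields, defaulting to zero when none match.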

src/unstract/sdk/utils/usage_handler.py

Lines changed: 5 additions & 2 deletions
@@ -8,6 +8,7 @@
 from unstract.sdk.audit import Audit
 from unstract.sdk.constants import LogLevel
 from unstract.sdk.tool.stream import StreamMixin
+from unstract.sdk.utils.token_counter import TokenCounter
 
 
 class UsageHandler(StreamMixin, BaseCallbackHandler):
@@ -32,8 +33,8 @@ class UsageHandler(StreamMixin, BaseCallbackHandler):
 
     def __init__(
         self,
-        token_counter: TokenCountingHandler,
         platform_api_key: str,
+        token_counter: Optional[TokenCountingHandler] = None,
         llm_model: LLM = None,
         embed_model: BaseEmbedding = None,
         event_starts_to_ignore: Optional[list[CBEventType]] = None,
@@ -90,9 +91,10 @@ def on_event_end(
             model_name = self.llm_model.metadata.model_name
             # Need to push the data to via platform service
             self.stream_log(log=f"Pushing llm usage for model {model_name}")
+            llm_token_counter: TokenCounter = TokenCounter.get_llm_token_counts(payload)
             Audit(log_level=self.log_level).push_usage_data(
                 platform_api_key=self.platform_api_key,
-                token_counter=self.token_counter,
+                token_counter=llm_token_counter,
                 event_type=event_type,
                 model_name=self.llm_model.metadata.model_name,
                 kwargs=self.kwargs,
@@ -113,3 +115,4 @@ def on_event_end(
                 model_name=self.embed_model.model_name,
                 kwargs=self.kwargs,
             )
+            self.token_counter.reset_counts()
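
Wiring sketch, assuming `llm` is a configured llama-index LLM (the manager alias mirrors the import used in callback_manager.py):

    from llama_index.core.callbacks import CallbackManager as LlamaIndexCallbackManager

    from unstract.sdk.utils.usage_handler import UsageHandler

    # token_counter is now optional: for LLM events the handler derives
    # counts from each response payload via TokenCounter instead of a
    # shared TokenCountingHandler.
    handler = UsageHandler(
        platform_api_key="PLATFORM_API_KEY",  # placeholder
        llm_model=llm,
    )
    llm.callback_manager = LlamaIndexCallbackManager(handlers=[handler])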
