Commit e2a2595

Token counter regression (#83)
* Llama-index-upgrade
* Upgrade SDK version
* pdm lock file update
* Token counter fixes
1 parent 4acfa68 commit e2a2595

File tree

1 file changed (+73, -39 lines)

src/unstract/sdk/utils/token_counter.py

Lines changed: 73 additions & 39 deletions
```diff
@@ -2,6 +2,8 @@
 
 from llama_index.core.callbacks.schema import EventPayload
 from llama_index.core.utilities.token_counting import TokenCounter
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion
 
 
 class Constants:
```
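
The two new imports let the counter dispatch on the concrete type of the raw response instead of probing attributes with `hasattr`, as the rest of the diff does. A minimal sketch of that idea, assuming only that OpenAI responses arrive as `ChatCompletion` objects while most other providers surface a plain dict through llama-index's `.raw` (the sample dict below is invented for illustration):

```python
# Minimal sketch of type-based dispatch on a raw LLM response.
# The sample dict is an invented shape, not a captured provider payload.
from openai.types.chat import ChatCompletion


def classify_raw(raw: object) -> str:
    # OpenAI's SDK returns typed pydantic objects such as ChatCompletion
    if isinstance(raw, ChatCompletion):
        return "openai"
    # Most other providers surface a plain dict via llama-index's .raw
    if isinstance(raw, dict):
        return "dict-style"
    return "unknown"


print(classify_raw({"usage": {"input_tokens": 10, "output_tokens": 3}}))  # dict-style
```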
```diff
@@ -10,6 +12,9 @@ class Constants:
     KEY_EVAL_COUNT = "eval_count"
     KEY_PROMPT_EVAL_COUNT = "prompt_eval_count"
     KEY_RAW_RESPONSE = "_raw_response"
+    KEY_TEXT_TOKEN_COUNT = "inputTextTokenCount"
+    KEY_TOKEN_COUNT = "tokenCount"
+    KEY_RESULTS = "results"
     INPUT_TOKENS = "input_tokens"
     OUTPUT_TOKENS = "output_tokens"
     PROMPT_TOKENS = "prompt_tokens"
```
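
The three new constants mirror the field names of Bedrock's Titan-style text responses, where the prompt token count sits at the top level and the completion count inside the first entry of `results`. A hedged sketch of that assumed shape, with invented numbers:

```python
# Hedged sketch: an assumed Titan-style Bedrock payload matching the new
# constants; the literal values are invented for illustration.
bedrock_raw = {
    "inputTextTokenCount": 42,  # Constants.KEY_TEXT_TOKEN_COUNT
    "results": [  # Constants.KEY_RESULTS
        {"tokenCount": 128},  # Constants.KEY_TOKEN_COUNT
    ],
}

prompt_tokens = bedrock_raw["inputTextTokenCount"]
completion_tokens = bedrock_raw["results"][0]["tokenCount"]
assert (prompt_tokens, completion_tokens) == (42, 128)
```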
```diff
@@ -32,62 +37,91 @@ def __init__(self, input_tokens, output_tokens):
 
     @staticmethod
     def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
-        token_counter = TokenCounter(
-            input_tokens=Constants.DEFAULT_TOKEN_COUNT,
-            output_tokens=Constants.DEFAULT_TOKEN_COUNT,
-        )
+        prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
+        completion_tokens = Constants.DEFAULT_TOKEN_COUNT
         if EventPayload.PROMPT in payload:
             completion_raw = payload.get(EventPayload.COMPLETION).raw
             if completion_raw:
-                if hasattr(completion_raw, Constants.KEY_USAGE):
-                    token_counts: dict[
-                        str, int
-                    ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
-                    token_counter = TokenCounter(
-                        input_tokens=token_counts[Constants.PROMPT_TOKENS],
-                        output_tokens=token_counts[Constants.COMPLETION_TOKENS],
-                    )
-                elif hasattr(completion_raw, Constants.KEY_RAW_RESPONSE):
-                    if hasattr(
-                        completion_raw._raw_response,
-                        Constants.KEY_USAGE_METADATA,
-                    ):
-                        usage = completion_raw._raw_response.usage_metadata
-                        token_counter = TokenCounter(
-                            input_tokens=usage.prompt_token_count,
-                            output_tokens=usage.candidates_token_count,
-                        )
-                else:
-                    prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
-                    completion_tokens = Constants.DEFAULT_TOKEN_COUNT
-                    if hasattr(completion_raw, Constants.KEY_PROMPT_EVAL_COUNT):
-                        prompt_tokens = completion_raw.prompt_eval_count
-                    if hasattr(completion_raw, Constants.KEY_EVAL_COUNT):
-                        completion_tokens = completion_raw.eval_count
-                    token_counter = TokenCounter(
-                        input_tokens=prompt_tokens,
-                        output_tokens=completion_tokens,
-                    )
+                # For OpenAI models, token count is part of ChatCompletion
+                if isinstance(completion_raw, ChatCompletion):
+                    if hasattr(completion_raw, Constants.KEY_USAGE):
+                        token_counts: dict[
+                            str, int
+                        ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                        prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                        completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+                # For other models
+                elif isinstance(completion_raw, dict):
+                    # For Gemini models
+                    if completion_raw.get(Constants.KEY_RAW_RESPONSE):
+                        if hasattr(
+                            completion_raw.get(Constants.KEY_RAW_RESPONSE),
+                            Constants.KEY_USAGE_METADATA,
+                        ):
+                            usage = completion_raw.get(
+                                Constants.KEY_RAW_RESPONSE
+                            ).usage_metadata
+                            prompt_tokens = usage.prompt_token_count
+                            completion_tokens = usage.candidates_token_count
+                    elif completion_raw.get(Constants.KEY_USAGE):
+                        token_counts: dict[
+                            str, int
+                        ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                        prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                        completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+                    # For Bedrock models
+                    elif Constants.KEY_TEXT_TOKEN_COUNT in completion_raw:
+                        prompt_tokens = completion_raw[Constants.KEY_TEXT_TOKEN_COUNT]
+                        if Constants.KEY_RESULTS in completion_raw:
+                            result_list: list = completion_raw[Constants.KEY_RESULTS]
+                            if len(result_list) > 0:
+                                result: dict = result_list[0]
+                                if Constants.KEY_TOKEN_COUNT in result:
+                                    completion_tokens = result.get(
+                                        Constants.KEY_TOKEN_COUNT
+                                    )
+                    else:
+                        if completion_raw.get(Constants.KEY_PROMPT_EVAL_COUNT):
+                            prompt_tokens = completion_raw.get(
+                                Constants.KEY_PROMPT_EVAL_COUNT
+                            )
+                        if completion_raw.get(Constants.KEY_EVAL_COUNT):
+                            completion_tokens = completion_raw.get(
+                                Constants.KEY_EVAL_COUNT
+                            )
```
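
The final `else` branch covers dict payloads that report `prompt_eval_count`/`eval_count` (Ollama, for instance, uses these field names, though the diff does not name the provider). Its behavior is easy to exercise in isolation; a hedged sketch, with `DEFAULT` standing in for `Constants.DEFAULT_TOKEN_COUNT`, whose value this page does not show:

```python
# Hedged sketch of the eval-count fallback branch; DEFAULT stands in for
# Constants.DEFAULT_TOKEN_COUNT, which this diff does not show.
DEFAULT = 0


def eval_count_tokens(raw: dict) -> tuple[int, int]:
    # Mirrors the truthiness checks above: missing counts keep the default
    prompt_tokens = raw.get("prompt_eval_count") or DEFAULT
    completion_tokens = raw.get("eval_count") or DEFAULT
    return prompt_tokens, completion_tokens


assert eval_count_tokens({"prompt_eval_count": 26, "eval_count": 282}) == (26, 282)
assert eval_count_tokens({}) == (0, 0)
```

The hunk continues with the messages path (Anthropic-style chat payloads) and the reworked shared helper; note that a single `TokenCounter` is now built once at the end rather than inside each branch: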
```diff
+        # For Anthropic models
         elif EventPayload.MESSAGES in payload:
             response_raw = payload.get(EventPayload.RESPONSE).raw
             if response_raw:
                 token_counts: dict[
                     str, int
                 ] = TokenCounter._get_prompt_completion_tokens(response_raw)
-                token_counter = TokenCounter(
-                    input_tokens=token_counts[Constants.PROMPT_TOKENS],
-                    output_tokens=token_counts[Constants.COMPLETION_TOKENS],
-                )
+                prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
 
+        token_counter = TokenCounter(
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+        )
         return token_counter
 
     @staticmethod
     def _get_prompt_completion_tokens(response) -> dict[str, int]:
+        usage = None
         prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
         completion_tokens = Constants.DEFAULT_TOKEN_COUNT
-
-        if hasattr(response, Constants.KEY_USAGE):
+        # For OpenAI models, response is a ChatCompletion whose usage is
+        # a CompletionUsage
+        if (
+            isinstance(response, ChatCompletion)
+            and hasattr(response, Constants.KEY_USAGE)
+            and isinstance(response.usage, CompletionUsage)
+        ):
             usage = response.usage
+        # For LLM models other than OpenAI, response is a dict
+        elif isinstance(response, dict) and Constants.KEY_USAGE in response:
+            usage = response.get(Constants.KEY_USAGE)
+
+        if usage:
             if hasattr(usage, Constants.INPUT_TOKENS):
                 prompt_tokens = usage.input_tokens
             elif hasattr(usage, Constants.PROMPT_TOKENS):
```
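
After this change, `_get_prompt_completion_tokens` accepts either a `ChatCompletion` or a dict, extracts a `usage` object, then probes it for provider-specific attribute names. A hedged sketch of that probing with an invented Anthropic-style stand-in (real usage objects come from the provider SDK via llama-index; `FakeUsage` is hypothetical):

```python
# Hedged sketch of the helper's attribute probing; FakeUsage is an
# invented stand-in with Anthropic-style field names, not SDK code.
from dataclasses import dataclass


@dataclass
class FakeUsage:  # hypothetical; real objects come from the provider SDK
    input_tokens: int
    output_tokens: int


response = {"usage": FakeUsage(input_tokens=12, output_tokens=34)}

usage = response.get("usage") if isinstance(response, dict) else None
prompt_tokens = usage.input_tokens if hasattr(usage, "input_tokens") else 0
completion_tokens = usage.output_tokens if hasattr(usage, "output_tokens") else 0
assert (prompt_tokens, completion_tokens) == (12, 34)
```

The `hasattr` ladder lets one helper serve providers that name the fields `input_tokens`/`output_tokens` and those that use `prompt_tokens`/`completion_tokens`, as the truncated tail of the diff shows.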
