Skip to content

Commit d353d16

Browse files
Change token counter implementation & support embedding token counting (defaults to tiktoken) (#92)
* Change token counter implementation & support embedding token counting (defaults to tiktoken)
* Refactor code
* Add a todo
1 parent 2f78672 commit d353d16

File tree

7 files changed

+309
-309
lines changed

7 files changed

+309
-309
lines changed

pdm.lock

Lines changed: 204 additions & 204 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies = [
4848
"llama-index-llms-vertex==0.2.2",
4949
"llama-index-llms-replicate==0.1.3",
5050
"llama-index-llms-ollama==0.2.2",
51-
"llama-index-llms-bedrock==0.1.12",
51+
"llama-index-llms-bedrock==0.1.13",
5252
# For Llama Parse X2Text
5353
"llama-parse==0.4.9",
5454
# OCR

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.45.2"
1+
__version__ = "0.46.0"
22

33

44
def get_sdk_version():

src/unstract/sdk/embedding.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from deprecated import deprecated
44
from llama_index.core.base.embeddings.base import Embedding
5+
from llama_index.core.callbacks import CallbackManager as LlamaIndexCallbackManager
56
from llama_index.core.embeddings import BaseEmbedding
67

78
from unstract.sdk.adapter import ToolAdapter
@@ -104,6 +105,17 @@ def get_class_name(self) -> str:
104105
"""
105106
return self._embedding_instance.class_name()
106107

108+
def get_callback_manager(self) -> LlamaIndexCallbackManager:
109+
"""Gets the llama-index callback manager set on the model.
110+
111+
Args:
112+
NA
113+
114+
Returns:
115+
llama-index callback manager
116+
"""
117+
return self._embedding_instance.callback_manager
118+
107119
@deprecated("Use Embedding instead of ToolEmbedding")
108120
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
109121
return self._get_embedding_length()

src/unstract/sdk/index.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ def index(
289289
try:
290290
if chunk_size == 0:
291291
parser = SimpleNodeParser.from_defaults(
292-
chunk_size=len(documents[0].text) + 10, chunk_overlap=0
292+
chunk_size=len(documents[0].text) + 10,
293+
chunk_overlap=0,
294+
callback_manager=embedding.get_callback_manager(),
293295
)
294296
nodes = parser.get_nodes_from_documents(
295297
documents, show_progress=True
@@ -301,7 +303,9 @@ def index(
301303
else:
302304
storage_context = vector_db.get_storage_context()
303305
parser = SimpleNodeParser.from_defaults(
304-
chunk_size=chunk_size, chunk_overlap=chunk_overlap
306+
chunk_size=chunk_size,
307+
chunk_overlap=chunk_overlap,
308+
callback_manager=embedding.get_callback_manager(),
305309
)
306310
self.tool.stream_log("Adding nodes to vector db...")
307311
# TODO: Phase 2:
@@ -320,6 +324,7 @@ def index(
320324
show_progress=True,
321325
embed_model=embedding,
322326
node_parser=parser,
327+
callback_manager=embedding.get_callback_manager(),
323328
)
324329
except Exception as e:
325330
self.tool.stream_log(
Lines changed: 81 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,10 @@
1-
from typing import Any
1+
from typing import Any, Union
22

33
from llama_index.core.callbacks.schema import EventPayload
4-
from llama_index.core.utilities.token_counting import TokenCounter
5-
from openai.types import CompletionUsage
6-
from openai.types.chat import ChatCompletion
4+
from llama_index.core.llms import ChatResponse, CompletionResponse
75

86

97
class Constants:
10-
KEY_USAGE = "usage"
11-
KEY_USAGE_METADATA = "usage_metadata"
12-
KEY_EVAL_COUNT = "eval_count"
13-
KEY_PROMPT_EVAL_COUNT = "prompt_eval_count"
14-
KEY_RAW_RESPONSE = "_raw_response"
15-
KEY_TEXT_TOKEN_COUNT = "inputTextTokenCount"
16-
KEY_TOKEN_COUNT = "tokenCount"
17-
KEY_RESULTS = "results"
18-
INPUT_TOKENS = "input_tokens"
19-
OUTPUT_TOKENS = "output_tokens"
20-
PROMPT_TOKENS = "prompt_tokens"
21-
COMPLETION_TOKENS = "completion_tokens"
228
DEFAULT_TOKEN_COUNT = 0
239

2410

@@ -35,69 +21,25 @@ def __init__(self, input_tokens, output_tokens):
3521
self.prompt_llm_token_count + self.completion_llm_token_count
3622
)
3723

24+
# TODO: Add unit test cases for the following function
25+
# for ease of maintenance
3826
@staticmethod
39-
def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
27+
def get_llm_token_counts(payload: dict[str, Any]):
4028
prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
4129
completion_tokens = Constants.DEFAULT_TOKEN_COUNT
4230
if EventPayload.PROMPT in payload:
43-
completion_raw = payload.get(EventPayload.COMPLETION).raw
44-
if completion_raw:
45-
# For Open AI models, token count is part of ChatCompletion
46-
if isinstance(completion_raw, ChatCompletion):
47-
if hasattr(completion_raw, Constants.KEY_USAGE):
48-
token_counts: dict[
49-
str, int
50-
] = TokenCounter._get_prompt_completion_tokens(completion_raw)
51-
prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
52-
completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
53-
# For other models
54-
elif isinstance(completion_raw, dict):
55-
# For Gemini models
56-
if completion_raw.get(Constants.KEY_RAW_RESPONSE):
57-
if hasattr(
58-
completion_raw.get(Constants.KEY_RAW_RESPONSE),
59-
Constants.KEY_USAGE_METADATA,
60-
):
61-
usage = completion_raw.get(
62-
Constants.KEY_RAW_RESPONSE
63-
).usage_metadata
64-
prompt_tokens = usage.prompt_token_count
65-
completion_tokens = usage.candidates_token_count
66-
elif completion_raw.get(Constants.KEY_USAGE):
67-
token_counts: dict[
68-
str, int
69-
] = TokenCounter._get_prompt_completion_tokens(completion_raw)
70-
prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
71-
completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
72-
# For Bedrock models
73-
elif Constants.KEY_TEXT_TOKEN_COUNT in completion_raw:
74-
prompt_tokens = completion_raw[Constants.KEY_TEXT_TOKEN_COUNT]
75-
if Constants.KEY_RESULTS in completion_raw:
76-
result_list: list = completion_raw[Constants.KEY_RESULTS]
77-
if len(result_list) > 0:
78-
result: dict = result_list[0]
79-
if Constants.KEY_TOKEN_COUNT in result:
80-
completion_tokens = result.get(
81-
Constants.KEY_TOKEN_COUNT
82-
)
83-
else:
84-
if completion_raw.get(Constants.KEY_PROMPT_EVAL_COUNT):
85-
prompt_tokens = completion_raw.get(
86-
Constants.KEY_PROMPT_EVAL_COUNT
87-
)
88-
if completion_raw.get(Constants.KEY_EVAL_COUNT):
89-
completion_tokens = completion_raw.get(
90-
Constants.KEY_EVAL_COUNT
91-
)
92-
# For Anthropic models
31+
response = payload.get(EventPayload.COMPLETION)
32+
(
33+
prompt_tokens,
34+
completion_tokens,
35+
) = TokenCounter._get_tokens_from_response(response)
9336
elif EventPayload.MESSAGES in payload:
94-
response_raw = payload.get(EventPayload.RESPONSE).raw
95-
if response_raw:
96-
token_counts: dict[
97-
str, int
98-
] = TokenCounter._get_prompt_completion_tokens(response_raw)
99-
prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
100-
completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
37+
response = payload.get(EventPayload.RESPONSE)
38+
if response:
39+
(
40+
prompt_tokens,
41+
completion_tokens,
42+
) = TokenCounter._get_tokens_from_response(response)
10143

10244
token_counter = TokenCounter(
10345
input_tokens=prompt_tokens,
@@ -106,33 +48,72 @@ def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
10648
return token_counter
10749

10850
@staticmethod
109-
def _get_prompt_completion_tokens(response) -> dict[str, int]:
110-
usage = None
111-
prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
112-
completion_tokens = Constants.DEFAULT_TOKEN_COUNT
113-
# For OpenAI models,response is an obj of CompletionUsage
114-
if (
115-
isinstance(response, ChatCompletion)
116-
and hasattr(response, Constants.KEY_USAGE)
117-
and isinstance(response.usage, CompletionUsage)
51+
def _get_tokens_from_response(
52+
response: Union[CompletionResponse, ChatResponse, dict]
53+
) -> tuple[int, int]:
54+
"""Get the token counts from a raw response."""
55+
prompt_tokens, completion_tokens = 0, 0
56+
if isinstance(response, CompletionResponse) or isinstance(
57+
response, ChatResponse
11858
):
119-
usage = response.usage
120-
# For LLM models other than OpenAI, response is a dict
121-
elif isinstance(response, dict) and Constants.KEY_USAGE in response:
122-
usage = response.get(Constants.KEY_USAGE)
59+
raw_response = response.raw
60+
if not isinstance(raw_response, dict):
61+
raw_response = dict(raw_response)
62+
63+
usage = raw_response.get("usage", None)
64+
if usage is None:
65+
if (
66+
hasattr(response, "additional_kwargs")
67+
and "prompt_tokens" in response.additional_kwargs
68+
):
69+
usage = response.additional_kwargs
70+
elif hasattr(response, "raw"):
71+
completion_raw = response.raw
72+
if ("_raw_response" in completion_raw) and hasattr(
73+
completion_raw["_raw_response"], "usage_metadata"
74+
):
75+
usage = completion_raw["_raw_response"].usage_metadata
76+
prompt_tokens = usage.prompt_token_count
77+
completion_tokens = usage.candidates_token_count
78+
return prompt_tokens, completion_tokens
79+
elif "inputTextTokenCount" in completion_raw:
80+
prompt_tokens = completion_raw["inputTextTokenCount"]
81+
if "results" in completion_raw:
82+
result_list: list = completion_raw["results"]
83+
if len(result_list) > 0:
84+
result: dict = result_list[0]
85+
if "tokenCount" in result:
86+
completion_tokens = result.get("tokenCount", 0)
87+
return prompt_tokens, completion_tokens
88+
else:
89+
usage = response.raw
90+
else:
91+
usage = response
92+
93+
if not isinstance(usage, dict):
94+
usage = usage.model_dump()
95+
96+
possible_input_keys = (
97+
"prompt_tokens",
98+
"input_tokens",
99+
"prompt_eval_count",
100+
)
101+
possible_output_keys = (
102+
"completion_tokens",
103+
"output_tokens",
104+
"eval_count",
105+
)
123106

124-
if usage:
125-
if hasattr(usage, Constants.INPUT_TOKENS):
126-
prompt_tokens = usage.input_tokens
127-
elif hasattr(usage, Constants.PROMPT_TOKENS):
128-
prompt_tokens = usage.prompt_tokens
107+
prompt_tokens = 0
108+
for input_key in possible_input_keys:
109+
if input_key in usage:
110+
prompt_tokens = int(usage[input_key])
111+
break
129112

130-
if hasattr(usage, Constants.OUTPUT_TOKENS):
131-
completion_tokens = usage.output_tokens
132-
elif hasattr(usage, Constants.COMPLETION_TOKENS):
133-
completion_tokens = usage.completion_tokens
113+
completion_tokens = 0
114+
for output_key in possible_output_keys:
115+
if output_key in usage:
116+
completion_tokens = int(usage[output_key])
117+
break
134118

135-
token_counts: dict[str, int] = dict()
136-
token_counts[Constants.PROMPT_TOKENS] = prompt_tokens
137-
token_counts[Constants.COMPLETION_TOKENS] = completion_tokens
138-
return token_counts
119+
return prompt_tokens, completion_tokens

src/unstract/sdk/vector_db.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def get_vector_store_index_from_storage_context(
124124
documents: Sequence[Document],
125125
storage_context: Optional[StorageContext] = None,
126126
show_progress: bool = False,
127+
callback_manager=None,
127128
**kwargs,
128129
) -> IndexType:
129130
if not self._embedding_instance:
@@ -135,6 +136,7 @@ def get_vector_store_index_from_storage_context(
135136
show_progress=show_progress,
136137
embed_model=self._embedding_instance,
137138
node_parser=parser,
139+
callback_manager=callback_manager,
138140
)
139141

140142
def get_vector_store_index(self, **kwargs: Any) -> VectorStoreIndex:
@@ -143,7 +145,7 @@ def get_vector_store_index(self, **kwargs: Any) -> VectorStoreIndex:
143145
return VectorStoreIndex.from_vector_store(
144146
vector_store=self._vector_db_instance,
145147
embed_model=self._embedding_instance,
146-
kwargs=kwargs,
148+
callback_manager=kwargs.get("callback_manager"),
147149
)
148150

149151
def get_storage_context(self) -> StorageContext:

0 commit comments

Comments (0)