1- from typing import Any
1+ from typing import Any , Union
22
33from llama_index .core .callbacks .schema import EventPayload
4- from llama_index .core .utilities .token_counting import TokenCounter
5- from openai .types import CompletionUsage
6- from openai .types .chat import ChatCompletion
4+ from llama_index .core .llms import ChatResponse , CompletionResponse
75
86
class Constants:
    """Shared constants for token counting."""

    # Fallback count reported when a response carries no usage information.
    DEFAULT_TOKEN_COUNT = 0
239
2410
@@ -35,69 +21,25 @@ def __init__(self, input_tokens, output_tokens):
3521 self .prompt_llm_token_count + self .completion_llm_token_count
3622 )
3723
24+ # TODO: Add unit test cases for the following function
25+ # for ease of manintenance
3826 @staticmethod
39- def get_llm_token_counts (payload : dict [str , Any ]) -> TokenCounter :
27+ def get_llm_token_counts (payload : dict [str , Any ]):
4028 prompt_tokens = Constants .DEFAULT_TOKEN_COUNT
4129 completion_tokens = Constants .DEFAULT_TOKEN_COUNT
4230 if EventPayload .PROMPT in payload :
43- completion_raw = payload .get (EventPayload .COMPLETION ).raw
44- if completion_raw :
45- # For Open AI models, token count is part of ChatCompletion
46- if isinstance (completion_raw , ChatCompletion ):
47- if hasattr (completion_raw , Constants .KEY_USAGE ):
48- token_counts : dict [
49- str , int
50- ] = TokenCounter ._get_prompt_completion_tokens (completion_raw )
51- prompt_tokens = token_counts [Constants .PROMPT_TOKENS ]
52- completion_tokens = token_counts [Constants .COMPLETION_TOKENS ]
53- # For other models
54- elif isinstance (completion_raw , dict ):
55- # For Gemini models
56- if completion_raw .get (Constants .KEY_RAW_RESPONSE ):
57- if hasattr (
58- completion_raw .get (Constants .KEY_RAW_RESPONSE ),
59- Constants .KEY_USAGE_METADATA ,
60- ):
61- usage = completion_raw .get (
62- Constants .KEY_RAW_RESPONSE
63- ).usage_metadata
64- prompt_tokens = usage .prompt_token_count
65- completion_tokens = usage .candidates_token_count
66- elif completion_raw .get (Constants .KEY_USAGE ):
67- token_counts : dict [
68- str , int
69- ] = TokenCounter ._get_prompt_completion_tokens (completion_raw )
70- prompt_tokens = token_counts [Constants .PROMPT_TOKENS ]
71- completion_tokens = token_counts [Constants .COMPLETION_TOKENS ]
72- # For Bedrock models
73- elif Constants .KEY_TEXT_TOKEN_COUNT in completion_raw :
74- prompt_tokens = completion_raw [Constants .KEY_TEXT_TOKEN_COUNT ]
75- if Constants .KEY_RESULTS in completion_raw :
76- result_list : list = completion_raw [Constants .KEY_RESULTS ]
77- if len (result_list ) > 0 :
78- result : dict = result_list [0 ]
79- if Constants .KEY_TOKEN_COUNT in result :
80- completion_tokens = result .get (
81- Constants .KEY_TOKEN_COUNT
82- )
83- else :
84- if completion_raw .get (Constants .KEY_PROMPT_EVAL_COUNT ):
85- prompt_tokens = completion_raw .get (
86- Constants .KEY_PROMPT_EVAL_COUNT
87- )
88- if completion_raw .get (Constants .KEY_EVAL_COUNT ):
89- completion_tokens = completion_raw .get (
90- Constants .KEY_EVAL_COUNT
91- )
92- # For Anthropic models
31+ response = payload .get (EventPayload .COMPLETION )
32+ (
33+ prompt_tokens ,
34+ completion_tokens ,
35+ ) = TokenCounter ._get_tokens_from_response (response )
9336 elif EventPayload .MESSAGES in payload :
94- response_raw = payload .get (EventPayload .RESPONSE ).raw
95- if response_raw :
96- token_counts : dict [
97- str , int
98- ] = TokenCounter ._get_prompt_completion_tokens (response_raw )
99- prompt_tokens = token_counts [Constants .PROMPT_TOKENS ]
100- completion_tokens = token_counts [Constants .COMPLETION_TOKENS ]
37+ response = payload .get (EventPayload .RESPONSE )
38+ if response :
39+ (
40+ prompt_tokens ,
41+ completion_tokens ,
42+ ) = TokenCounter ._get_tokens_from_response (response )
10143
10244 token_counter = TokenCounter (
10345 input_tokens = prompt_tokens ,
@@ -106,33 +48,72 @@ def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
10648 return token_counter
10749
10850 @staticmethod
109- def _get_prompt_completion_tokens (response ) -> dict [str , int ]:
110- usage = None
111- prompt_tokens = Constants .DEFAULT_TOKEN_COUNT
112- completion_tokens = Constants .DEFAULT_TOKEN_COUNT
113- # For OpenAI models,response is an obj of CompletionUsage
114- if (
115- isinstance (response , ChatCompletion )
116- and hasattr (response , Constants .KEY_USAGE )
117- and isinstance (response .usage , CompletionUsage )
51+ def _get_tokens_from_response (
52+ response : Union [CompletionResponse , ChatResponse , dict ]
53+ ) -> tuple [int , int ]:
54+ """Get the token counts from a raw response."""
55+ prompt_tokens , completion_tokens = 0 , 0
56+ if isinstance (response , CompletionResponse ) or isinstance (
57+ response , ChatResponse
11858 ):
119- usage = response .usage
120- # For LLM models other than OpenAI, response is a dict
121- elif isinstance (response , dict ) and Constants .KEY_USAGE in response :
122- usage = response .get (Constants .KEY_USAGE )
59+ raw_response = response .raw
60+ if not isinstance (raw_response , dict ):
61+ raw_response = dict (raw_response )
62+
63+ usage = raw_response .get ("usage" , None )
64+ if usage is None :
65+ if (
66+ hasattr (response , "additional_kwargs" )
67+ and "prompt_tokens" in response .additional_kwargs
68+ ):
69+ usage = response .additional_kwargs
70+ elif hasattr (response , "raw" ):
71+ completion_raw = response .raw
72+ if ("_raw_response" in completion_raw ) and hasattr (
73+ completion_raw ["_raw_response" ], "usage_metadata"
74+ ):
75+ usage = completion_raw ["_raw_response" ].usage_metadata
76+ prompt_tokens = usage .prompt_token_count
77+ completion_tokens = usage .candidates_token_count
78+ return prompt_tokens , completion_tokens
79+ elif "inputTextTokenCount" in completion_raw :
80+ prompt_tokens = completion_raw ["inputTextTokenCount" ]
81+ if "results" in completion_raw :
82+ result_list : list = completion_raw ["results" ]
83+ if len (result_list ) > 0 :
84+ result : dict = result_list [0 ]
85+ if "tokenCount" in result :
86+ completion_tokens = result .get ("tokenCount" , 0 )
87+ return prompt_tokens , completion_tokens
88+ else :
89+ usage = response .raw
90+ else :
91+ usage = response
92+
93+ if not isinstance (usage , dict ):
94+ usage = usage .model_dump ()
95+
96+ possible_input_keys = (
97+ "prompt_tokens" ,
98+ "input_tokens" ,
99+ "prompt_eval_count" ,
100+ )
101+ possible_output_keys = (
102+ "completion_tokens" ,
103+ "output_tokens" ,
104+ "eval_count" ,
105+ )
123106
124- if usage :
125- if hasattr ( usage , Constants . INPUT_TOKENS ) :
126- prompt_tokens = usage . input_tokens
127- elif hasattr (usage , Constants . PROMPT_TOKENS ):
128- prompt_tokens = usage . prompt_tokens
107+ prompt_tokens = 0
108+ for input_key in possible_input_keys :
109+ if input_key in usage :
110+ prompt_tokens = int (usage [ input_key ])
111+ break
129112
130- if hasattr (usage , Constants .OUTPUT_TOKENS ):
131- completion_tokens = usage .output_tokens
132- elif hasattr (usage , Constants .COMPLETION_TOKENS ):
133- completion_tokens = usage .completion_tokens
113+ completion_tokens = 0
114+ for output_key in possible_output_keys :
115+ if output_key in usage :
116+ completion_tokens = int (usage [output_key ])
117+ break
134118
135- token_counts : dict [str , int ] = dict ()
136- token_counts [Constants .PROMPT_TOKENS ] = prompt_tokens
137- token_counts [Constants .COMPLETION_TOKENS ] = completion_tokens
138- return token_counts
119+ return prompt_tokens , completion_tokens
0 commit comments