 
 from llama_index.core.callbacks.schema import EventPayload
 from llama_index.core.utilities.token_counting import TokenCounter
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion
 
 
 class Constants:
@@ -10,6 +12,9 @@ class Constants:
     KEY_EVAL_COUNT = "eval_count"
     KEY_PROMPT_EVAL_COUNT = "prompt_eval_count"
     KEY_RAW_RESPONSE = "_raw_response"
+    KEY_TEXT_TOKEN_COUNT = "inputTextTokenCount"
+    KEY_TOKEN_COUNT = "tokenCount"
+    KEY_RESULTS = "results"
     INPUT_TOKENS = "input_tokens"
     OUTPUT_TOKENS = "output_tokens"
     PROMPT_TOKENS = "prompt_tokens"
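The three keys added above match the field names of an Amazon Bedrock (Titan-style) completion payload. A minimal sketch of the shape being indexed into — the payload below is invented for illustration and is not taken from this diff:

    # Hypothetical Bedrock Titan-style raw completion; the field names mirror
    # the new constants, the values are made up.
    completion_raw = {
        "inputTextTokenCount": 12,
        "results": [{"tokenCount": 34, "outputText": "..."}],
    }
    prompt_tokens = completion_raw["inputTextTokenCount"]  # input side
    completion_tokens = completion_raw["results"][0]["tokenCount"]  # output side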
@@ -32,62 +37,91 @@ def __init__(self, input_tokens, output_tokens):
 
     @staticmethod
     def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
-        token_counter = TokenCounter(
-            input_tokens=Constants.DEFAULT_TOKEN_COUNT,
-            output_tokens=Constants.DEFAULT_TOKEN_COUNT,
-        )
+        prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
+        completion_tokens = Constants.DEFAULT_TOKEN_COUNT
         if EventPayload.PROMPT in payload:
             completion_raw = payload.get(EventPayload.COMPLETION).raw
             if completion_raw:
-                if hasattr(completion_raw, Constants.KEY_USAGE):
-                    token_counts: dict[
-                        str, int
-                    ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
-                    token_counter = TokenCounter(
-                        input_tokens=token_counts[Constants.PROMPT_TOKENS],
-                        output_tokens=token_counts[Constants.COMPLETION_TOKENS],
-                    )
-                elif hasattr(completion_raw, Constants.KEY_RAW_RESPONSE):
-                    if hasattr(
-                        completion_raw._raw_response,
-                        Constants.KEY_USAGE_METADATA,
-                    ):
-                        usage = completion_raw._raw_response.usage_metadata
-                        token_counter = TokenCounter(
-                            input_tokens=usage.prompt_token_count,
-                            output_tokens=usage.candidates_token_count,
-                        )
-                else:
-                    prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
-                    completion_tokens = Constants.DEFAULT_TOKEN_COUNT
-                    if hasattr(completion_raw, Constants.KEY_PROMPT_EVAL_COUNT):
-                        prompt_tokens = completion_raw.prompt_eval_count
-                    if hasattr(completion_raw, Constants.KEY_EVAL_COUNT):
-                        completion_tokens = completion_raw.eval_count
-                    token_counter = TokenCounter(
-                        input_tokens=prompt_tokens,
-                        output_tokens=completion_tokens,
-                    )
+                # For OpenAI models, the token count is part of the ChatCompletion object
+                if isinstance(completion_raw, ChatCompletion):
+                    if hasattr(completion_raw, Constants.KEY_USAGE):
+                        token_counts: dict[
+                            str, int
+                        ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                        prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                        completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+                # For other models
+                elif isinstance(completion_raw, dict):
+                    # For Gemini models
+                    if completion_raw.get(Constants.KEY_RAW_RESPONSE):
+                        if hasattr(
+                            completion_raw.get(Constants.KEY_RAW_RESPONSE),
+                            Constants.KEY_USAGE_METADATA,
+                        ):
+                            usage = completion_raw.get(
+                                Constants.KEY_RAW_RESPONSE
+                            ).usage_metadata
+                            prompt_tokens = usage.prompt_token_count
+                            completion_tokens = usage.candidates_token_count
+                    elif completion_raw.get(Constants.KEY_USAGE):
+                        token_counts: dict[
+                            str, int
+                        ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                        prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                        completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+                    # For Bedrock models
+                    elif Constants.KEY_TEXT_TOKEN_COUNT in completion_raw:
+                        prompt_tokens = completion_raw[Constants.KEY_TEXT_TOKEN_COUNT]
+                        if Constants.KEY_RESULTS in completion_raw:
+                            result_list: list = completion_raw[Constants.KEY_RESULTS]
+                            if len(result_list) > 0:
+                                result: dict = result_list[0]
+                                if Constants.KEY_TOKEN_COUNT in result:
+                                    completion_tokens = result.get(
+                                        Constants.KEY_TOKEN_COUNT
+                                    )
+                    else:
+                        if completion_raw.get(Constants.KEY_PROMPT_EVAL_COUNT):
+                            prompt_tokens = completion_raw.get(
+                                Constants.KEY_PROMPT_EVAL_COUNT
+                            )
+                        if completion_raw.get(Constants.KEY_EVAL_COUNT):
+                            completion_tokens = completion_raw.get(
+                                Constants.KEY_EVAL_COUNT
+                            )
+        # For Anthropic models
         elif EventPayload.MESSAGES in payload:
             response_raw = payload.get(EventPayload.RESPONSE).raw
             if response_raw:
                 token_counts: dict[
                     str, int
                 ] = TokenCounter._get_prompt_completion_tokens(response_raw)
-                token_counter = TokenCounter(
-                    input_tokens=token_counts[Constants.PROMPT_TOKENS],
-                    output_tokens=token_counts[Constants.COMPLETION_TOKENS],
-                )
+                prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
 
+        token_counter = TokenCounter(
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+        )
         return token_counter
 
     @staticmethod
     def _get_prompt_completion_tokens(response) -> dict[str, int]:
+        usage = None
         prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
         completion_tokens = Constants.DEFAULT_TOKEN_COUNT
-
-        if hasattr(response, Constants.KEY_USAGE):
+        # For OpenAI models, response is a ChatCompletion whose usage is a CompletionUsage object
+        if (
+            isinstance(response, ChatCompletion)
+            and hasattr(response, Constants.KEY_USAGE)
+            and isinstance(response.usage, CompletionUsage)
+        ):
             usage = response.usage
+        # For LLMs other than OpenAI, response is a dict
+        elif isinstance(response, dict) and Constants.KEY_USAGE in response:
+            usage = response.get(Constants.KEY_USAGE)
+
+        if usage:
             if hasattr(usage, Constants.INPUT_TOKENS):
                 prompt_tokens = usage.input_tokens
             elif hasattr(usage, Constants.PROMPT_TOKENS):
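With this refactor, every provider branch funnels into the same pair of local variables and a single TokenCounter is constructed at the end. A minimal usage sketch for a dict-shaped (Ollama-style) payload — it assumes the enclosing class is the TokenCounter referenced in the method bodies and uses SimpleNamespace as a stand-in for the completion object, neither of which is taken from this PR:

    from types import SimpleNamespace

    from llama_index.core.callbacks.schema import EventPayload

    # Ollama-style raw completion: eval counts live directly in the dict,
    # so the final else-branch of get_llm_token_counts applies.
    completion = SimpleNamespace(raw={"prompt_eval_count": 11, "eval_count": 42})
    payload = {
        EventPayload.PROMPT: "What is a token?",
        EventPayload.COMPLETION: completion,
    }

    counter = TokenCounter.get_llm_token_counts(payload)
    # Expected: counter.input_tokens == 11, counter.output_tokens == 42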