@@ -99,7 +99,7 @@ def get_num_tokens_from_messages(
99
99
except Exception as e :
100
100
tokenizer = TokenizerManage .get_tokenizer ()
101
101
return sum ([len (tokenizer .encode (get_buffer_string ([m ]))) for m in messages ])
102
- return self .usage_metadata .get ('input_tokens' , 0 )
102
+ return self .usage_metadata .get ('input_tokens' , self . usage_metadata . get ( 'prompt_tokens' , 0 ) )
103
103
104
104
def get_num_tokens (self , text : str ) -> int :
105
105
if self .usage_metadata is None or self .usage_metadata == {}:
@@ -108,7 +108,8 @@ def get_num_tokens(self, text: str) -> int:
108
108
except Exception as e :
109
109
tokenizer = TokenizerManage .get_tokenizer ()
110
110
return len (tokenizer .encode (text ))
111
- return self .get_last_generation_info ().get ('output_tokens' , 0 )
111
+ return self .get_last_generation_info ().get ('output_tokens' ,
112
+ self .get_last_generation_info ().get ('completion_tokens' , 0 ))
112
113
113
114
def _stream (self , * args : Any , ** kwargs : Any ) -> Iterator [ChatGenerationChunk ]:
114
115
kwargs ['stream_usage' ] = True
@@ -133,7 +134,7 @@ def _convert_chunk_to_generation_chunk(
133
134
)
134
135
135
136
usage_metadata : Optional [UsageMetadata ] = (
136
- _create_usage_metadata (token_usage ) if token_usage else None
137
+ _create_usage_metadata (token_usage ) if token_usage and token_usage . get ( "prompt_tokens" ) else None
137
138
)
138
139
if len (choices ) == 0 :
139
140
# logprobs is implicitly None
0 commit comments