Commit d2c519f

Merge pull request #14133 from retanoj/fix/gemini-token-count
Fix token count error for gemini cli
2 parents: 5ab3d74 + b1686ec
4 files changed (+74 additions, -6 deletions)


litellm/google_genai/adapters/handler.py (4 additions, 0 deletions)

@@ -37,6 +37,10 @@ def _prepare_completion_kwargs(
 
     completion_kwargs: Dict[str, Any] = dict(completion_request)
 
+    # feed metadata for custom callback
+    if extra_kwargs is not None and "metadata" in extra_kwargs:
+        completion_kwargs["metadata"] = extra_kwargs["metadata"]
+
     if stream:
         completion_kwargs["stream"] = stream
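The forwarded metadata matters because completion_kwargs is what litellm's callback machinery eventually receives, so generate_content callers can now tag requests for custom callbacks. Below is a minimal sketch of such a callback using litellm's CustomLogger interface; the MetadataLogger class and the trace_id key are illustrative, and reading metadata out of litellm_params follows litellm's callback documentation:

import litellm
from litellm.integrations.custom_logger import CustomLogger


class MetadataLogger(CustomLogger):
    # async_log_success_event is one of litellm's CustomLogger hooks
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # per-request metadata is surfaced under litellm_params in callback kwargs
        metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
        print("trace_id:", metadata.get("trace_id"))


litellm.callbacks = [MetadataLogger()]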

litellm/proxy/google_endpoints/endpoints.py (21 additions, 6 deletions)

@@ -173,15 +173,24 @@ async def google_count_tokens(request: Request, model_name: str):
     """
     from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
     from litellm.proxy.proxy_server import token_counter as internal_token_counter
+    from litellm.google_genai.adapters.transformation import GoogleGenAIAdapter
 
     data = await _read_request_body(request=request)
     contents = data.get("contents", [])
     # Create TokenCountRequest for the internal endpoint
     from litellm.proxy._types import TokenCountRequest
 
+    # Translate contents to openai-format messages using the adapter
+    messages = (
+        GoogleGenAIAdapter()
+        .translate_generate_content_to_completion(model_name, contents)
+        .get("messages", [])
+    )
+
     token_request = TokenCountRequest(
         model=model_name,
-        contents=contents
+        contents=contents,
+        messages=messages,  # compatibility when used with an openai-like endpoint
     )
 
     # Call the internal token counter function with direct request flag set to False
@@ -192,11 +201,17 @@ async def google_count_tokens(request: Request, model_name: str):
     if token_response is not None:
         # cast the response to the well known format
         original_response: dict = token_response.original_response or {}
-        return TokenCountDetailsResponse(
-            totalTokens=original_response.get("totalTokens", 0),
-            promptTokensDetails=original_response.get("promptTokensDetails", []),
-        )
-
+        if original_response:
+            return TokenCountDetailsResponse(
+                totalTokens=original_response.get("totalTokens", 0),
+                promptTokensDetails=original_response.get("promptTokensDetails", []),
+            )
+        else:
+            return TokenCountDetailsResponse(
+                totalTokens=token_response.total_tokens or 0,
+                promptTokensDetails=[],
+            )
+
     #########################################################
     # Return the response in the well known format
     #########################################################

tests/test_litellm/proxy/google_endpoints/__init__.py

Whitespace-only changes.

New test file (49 additions, 0 deletions):

@@ -0,0 +1,49 @@
+"""
+Test for google_endpoints/endpoints.py
+"""
+import pytest
+import sys, os
+from dotenv import load_dotenv
+
+
+from litellm.proxy.google_endpoints.endpoints import google_count_tokens
+from litellm.types.llms.vertex_ai import TokenCountDetailsResponse
+from starlette.requests import Request
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)
+
+
+@pytest.mark.asyncio
+async def test_proxy_gemini_to_openai_like_model_token_counting():
+    """
+    Test the token counting endpoint for proxying gemini to openai-like models.
+    """
+    response: TokenCountDetailsResponse = await google_count_tokens(
+        request=Request(
+            scope={
+                "type": "http",
+                # _read_request_body short-circuits on a pre-parsed body,
+                # so no network transport is needed in the test
+                "parsed_body": (
+                    ["contents"],
+                    {
+                        "contents": [
+                            {"parts": [{"text": "Hello, how are you?"}]}
+                        ]
+                    },
+                ),
+            }
+        ),
+        model_name="volcengine/foo",
+    )
+
+    assert response.get("totalTokens") > 0
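For completeness, a hedged sketch of exercising the fixed endpoint against a running proxy. The /v1beta/models/{model}:countTokens path follows the Gemini API convention that the Gemini CLI speaks; the base URL and the my-openai-like-model alias are assumptions to adapt to your deployment:

import httpx

# countTokens request in Gemini wire format; the route shape is assumed
resp = httpx.post(
    "http://localhost:4000/v1beta/models/my-openai-like-model:countTokens",
    json={"contents": [{"parts": [{"text": "Hello, how are you?"}]}]},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"totalTokens": 7, "promptTokensDetails": []}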
