Commit 002c2f1

fix token count error when proxying gemini cli to an openai-like model
1 parent 34275ab commit 002c2f1
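
For context: the Gemini CLI sends Gemini-style countTokens payloads (a `contents` list), and the diff below suggests two gaps when the proxy routes such a request to an OpenAI-like backend: the internal token counter received no OpenAI-style `messages` to count, and the response mapping assumed a Gemini-native provider payload. A sketch of the request body involved, with the shape taken from the test added in this commit:

# Gemini-style countTokens payload, as sent by the Gemini CLI.
# Shape taken from the test added in this commit.
count_tokens_body = {
    "contents": [
        {"parts": [{"text": "Hello, how are you?"}]}
    ]
}
# The fix translates these contents into OpenAI-style messages before calling
# the internal token counter, so OpenAI-like backends can count them too.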

File tree

4 files changed: +71 −5 lines

litellm/google_genai/adapters/handler.py

Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,10 @@ def _prepare_completion_kwargs(
 
         completion_kwargs: Dict[str, Any] = dict(completion_request)
 
+        # feed metadata for custom callback
+        # if 'metadata' in extra_kwargs:
+        #     completion_kwargs['metadata'] = extra_kwargs['metadata']
+
         if stream:
             completion_kwargs["stream"] = stream
 
litellm/proxy/google_endpoints/endpoints.py

Lines changed: 18 additions & 5 deletions

@@ -173,15 +173,22 @@ async def google_count_tokens(request: Request, model_name: str):
     """
     from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
     from litellm.proxy.proxy_server import token_counter as internal_token_counter
+    from litellm.google_genai.adapters.transformation import GoogleGenAIAdapter
 
     data = await _read_request_body(request=request)
     contents = data.get("contents", [])
     # Create TokenCountRequest for the internal endpoint
     from litellm.proxy._types import TokenCountRequest
 
+    # Translate contents to openai-format messages using the adapter
+    messages = (GoogleGenAIAdapter()
+                .translate_generate_content_to_completion(model_name, contents)
+                .get("messages", []))
+
     token_request = TokenCountRequest(
         model=model_name,
-        contents=contents
+        contents=contents,
+        messages=messages,  # compatibility when using an openai-like endpoint
     )
 
     # Call the internal token counter function with direct request flag set to False
@@ -192,10 +199,16 @@ async def google_count_tokens(request: Request, model_name: str):
     if token_response is not None:
         # cast the response to the well known format
         original_response: dict = token_response.original_response or {}
-        return TokenCountDetailsResponse(
-            totalTokens=original_response.get("totalTokens", 0),
-            promptTokensDetails=original_response.get("promptTokensDetails", []),
-        )
+        if original_response:
+            return TokenCountDetailsResponse(
+                totalTokens=original_response.get("totalTokens", 0),
+                promptTokensDetails=original_response.get("promptTokensDetails", []),
+            )
+        else:
+            return TokenCountDetailsResponse(
+                totalTokens=token_response.total_tokens or 0,
+                promptTokensDetails=[],
+            )
 
     #########################################################
     # Return the response in the well known format
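
Taken together, the two hunks do two things: translate Gemini `contents` into OpenAI `messages` before counting, and fall back to the counter's generic `total_tokens` when the backend returns no Gemini-native payload. A minimal sketch of the translation step, assuming litellm is importable; the commented output shape is an assumption inferred from this diff, not documented adapter behavior:

from litellm.google_genai.adapters.transformation import GoogleGenAIAdapter

contents = [{"parts": [{"text": "Hello, how are you?"}]}]

# Same call the endpoint now makes; "volcengine/foo" is just the
# illustrative model name used in the test below.
completion_request = GoogleGenAIAdapter().translate_generate_content_to_completion(
    "volcengine/foo", contents
)
messages = completion_request.get("messages", [])
# Assumed result: OpenAI-style messages the internal token counter can handle,
# e.g. [{"role": "user", "content": "Hello, how are you?"}]
print(messages)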

tests/test_litellm/proxy/google_endpoints/__init__.py

Whitespace-only changes.

tests/test_litellm/proxy/google_endpoints/ (new test file; exact filename not shown in this view)

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+"""
+Test for google_endpoints/endpoints.py
+"""
+import pytest
+import sys, os
+from dotenv import load_dotenv
+
+
+from litellm.proxy.google_endpoints.endpoints import google_count_tokens
+from litellm.types.llms.vertex_ai import TokenCountDetailsResponse
+from starlette.requests import Request
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)
+
+@pytest.mark.asyncio
+async def test_proxy_gemini_to_openai_like_model_token_counting():
+    """
+    Test the token counting endpoint for proxying gemini to openai-like models.
+    """
+    response: TokenCountDetailsResponse = await google_count_tokens(
+        request=Request(
+            scope={
+                "type": "http",
+                "parsed_body": (
+                    [
+                        "contents"
+                    ],
+                    {
+                        "contents": [
+                            {
+                                "parts": [
+                                    {
+                                        "text": "Hello, how are you?"
+                                    }
+                                ]
+                            }
+                        ]
+                    }
+                )
+            }
+        ),
+        model_name="volcengine/foo",
+    )
+
+    assert response.get("totalTokens") > 0
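
The test exercises the fallback branch end to end: an OpenAI-like model yields no Gemini-native `original_response`, so `totalTokens` must come from the counter's generic `total_tokens`. A standalone sketch of that mapping with a stand-in object (the class and values are hypothetical, for illustration only):

class FakeTokenCountResponse:
    """Stand-in for the internal token counter's response (illustrative)."""
    original_response = None  # openai-like backends return no Gemini payload
    total_tokens = 7          # hypothetical count

token_response = FakeTokenCountResponse()
original_response = token_response.original_response or {}

if original_response:
    # Gemini-native backend: pass the provider's fields through.
    result = {
        "totalTokens": original_response.get("totalTokens", 0),
        "promptTokensDetails": original_response.get("promptTokensDetails", []),
    }
else:
    # Fallback added in this commit: use the generic total_tokens count.
    result = {"totalTokens": token_response.total_tokens or 0, "promptTokensDetails": []}

assert result == {"totalTokens": 7, "promptTokensDetails": []}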
