Skip to content

Commit 63c4a30

Browse files
committed
TestVertexAIGPTOSSTransformation
1 parent 35b501b commit 63c4a30

File tree

1 file changed: +23 additions, −8 deletions

tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/gpt_oss/test_vertex_ai_gpt_oss_transformation.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33
import sys
4-
from unittest.mock import AsyncMock, MagicMock, patch
4+
from unittest.mock import MagicMock, patch
55

66
import httpx
77
import pytest
@@ -51,9 +51,12 @@ async def test_vertex_ai_gpt_oss_simple_request():
5151
with the correct request body.
5252
"""
5353
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
54+
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
55+
VertexLLM,
56+
)
5457

5558
# Mock response
56-
mock_response = AsyncMock()
59+
mock_response = MagicMock()
5760
mock_response.status_code = 200
5861
mock_response.headers = {}
5962
mock_response.json.return_value = {
@@ -80,7 +83,11 @@ async def test_vertex_ai_gpt_oss_simple_request():
8083

8184
client = AsyncHTTPHandler()
8285

83-
with patch.object(client, "post", return_value=mock_response) as mock_post:
86+
async def mock_post_func(*args, **kwargs):
87+
return mock_response
88+
89+
with patch.object(client, "post", side_effect=mock_post_func) as mock_post, \
90+
patch.object(VertexLLM, "_ensure_access_token", return_value=("fake-token", "pathrise-convert-1606954137718")):
8491
response = await litellm.acompletion(
8592
model="vertex_ai/openai/gpt-oss-20b-maas",
8693
messages=[
@@ -103,7 +110,8 @@ async def test_vertex_ai_gpt_oss_simple_request():
103110

104111
# Get the call arguments
105112
call_args = mock_post.call_args
106-
called_url = call_args[0][0] # First positional argument is the URL
113+
# For side_effect, the URL is passed as kwargs['url']
114+
called_url = call_args.kwargs["url"]
107115
request_body = json.loads(call_args.kwargs["data"])
108116

109117
# Verify the URL
@@ -128,7 +136,7 @@ async def test_vertex_ai_gpt_oss_simple_request():
128136
assert request_body == expected_request_body
129137

130138
# Verify response structure
131-
assert response.model == "vertex_ai/openai/gpt-oss-20b-maas"
139+
assert response.model == "openai/gpt-oss-20b-maas"
132140
assert len(response.choices) == 1
133141
assert response.choices[0].message.role == "assistant"
134142

@@ -140,9 +148,12 @@ async def test_vertex_ai_gpt_oss_reasoning_effort():
140148
for GPT-OSS models.
141149
"""
142150
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
151+
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
152+
VertexLLM,
153+
)
143154

144155
# Mock response
145-
mock_response = AsyncMock()
156+
mock_response = MagicMock()
146157
mock_response.status_code = 200
147158
mock_response.headers = {}
148159
mock_response.json.return_value = {
@@ -169,7 +180,11 @@ async def test_vertex_ai_gpt_oss_reasoning_effort():
169180

170181
client = AsyncHTTPHandler()
171182

172-
with patch.object(client, "post", return_value=mock_response) as mock_post:
183+
async def mock_post_func(*args, **kwargs):
184+
return mock_response
185+
186+
with patch.object(client, "post", side_effect=mock_post_func) as mock_post, \
187+
patch.object(VertexLLM, "_ensure_access_token", return_value=("fake-token", "pathrise-convert-1606954137718")):
173188
response = await litellm.acompletion(
174189
model="vertex_ai/openai/gpt-oss-20b-maas",
175190
messages=[
@@ -218,6 +233,6 @@ async def test_vertex_ai_gpt_oss_reasoning_effort():
218233
assert request_body == expected_request_body
219234

220235
# Verify response structure
221-
assert response.model == "vertex_ai/openai/gpt-oss-20b-maas"
236+
assert response.model == "openai/gpt-oss-20b-maas"
222237
assert len(response.choices) == 1
223238
assert response.choices[0].message.role == "assistant"

0 commit comments

Comments
 (0)