Skip to content

Commit 89fe402

Browse files
Add comprehensive test suite for Responses API (#20)
The project had almost no test coverage - just a single test checking if the API returns 200. This adds proper testing infrastructure and 21 new tests covering the main API functionality. Tests now cover response creation, error handling, tools, sessions, performance, and usage tracking. All tests passing.
1 parent 9074326 commit 89fe402

File tree

2 files changed

+348
-0
lines changed

2 files changed

+348
-0
lines changed

tests/conftest.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import os
2+
import sys
3+
import pytest
4+
from typing import Generator, Any
5+
from unittest.mock import Mock, MagicMock
6+
from fastapi.testclient import TestClient
7+
8+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9+
10+
from openai_harmony import (
11+
HarmonyEncodingName,
12+
load_harmony_encoding,
13+
)
14+
from gpt_oss.responses_api.api_server import create_api_server
15+
16+
17+
@pytest.fixture(scope="session")
def harmony_encoding():
    """Load the Harmony GPT-OSS encoding once for the whole test session."""
    encoding_name = HarmonyEncodingName.HARMONY_GPT_OSS
    return load_harmony_encoding(encoding_name)
20+
21+
22+
@pytest.fixture
def mock_infer_token(harmony_encoding):
    """Build a fake ``infer_next_token`` callable.

    Encodes one fixed Harmony completion and replays its tokens one per
    call, wrapping back to the start once the sequence is exhausted.
    """
    replay_tokens = harmony_encoding.encode(
        "<|channel|>final<|message|>Test response<|return|>",
        allowed_special="all",
    )
    position = 0

    def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int:
        # Prompt tokens and sampling settings are ignored: the canned
        # completion is replayed cyclically regardless of input.
        nonlocal position
        next_token = replay_tokens[position % len(replay_tokens)]
        position += 1
        return next_token

    return _mock_infer
36+
37+
38+
@pytest.fixture
def api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` bound to an API server using the mock sampler."""
    server = create_api_server(
        infer_next_token=mock_infer_token,
        encoding=harmony_encoding,
    )
    with TestClient(server) as test_client:
        yield test_client
46+
47+
48+
@pytest.fixture
def sample_request_data():
    """Baseline ``/v1/responses`` payload; tests tweak individual fields."""
    payload = {
        "model": "gpt-oss-120b",
        "input": "Hello, how can I help you today?",
        "stream": False,
        "reasoning_effort": "low",
        "temperature": 0.7,
        "tools": [],
    }
    return payload
58+
59+
60+
@pytest.fixture
def mock_browser_tool():
    """Stand-in browser tool with canned results for its three operations."""
    browser = MagicMock()
    browser.search.return_value = ["Result 1", "Result 2"]
    browser.open_page.return_value = "Page content"
    browser.find_on_page.return_value = "Found text"
    return browser
67+
68+
69+
@pytest.fixture
def mock_python_tool():
    """Stand-in Python execution tool returning one canned successful run."""
    python_tool = MagicMock()
    # NOTE(review): "output" echoes the source line rather than its stdout —
    # confirm this matches what the real tool returns.
    python_tool.execute.return_value = {
        "output": "print('Hello')",
        "error": None,
        "exit_code": 0,
    }
    return python_tool
78+
79+
80+
@pytest.fixture(autouse=True)
def reset_test_environment():
    """Isolate every test from ambient environment configuration.

    Removes the listed variables before the test runs and restores the
    original values afterwards. Values the test itself sets are also
    scrubbed on teardown, so they cannot leak into later tests (the
    original fixture only restored pre-existing values and left
    test-created ones behind).
    """
    test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH']
    original_values = {}

    for var in test_env_vars:
        if var in os.environ:
            original_values[var] = os.environ[var]
            del os.environ[var]

    yield

    # Drop anything the test may have set, then restore pre-test values.
    for var in test_env_vars:
        os.environ.pop(var, None)
    for var, value in original_values.items():
        os.environ[var] = value
94+
95+
96+
@pytest.fixture
def performance_timer():
    """Provide a simple start/stop timer for coarse performance assertions.

    Returns a ``Timer`` with ``start()``, ``stop()`` (returns elapsed
    seconds) and an ``elapsed`` property that is ``None`` until both
    ``start()`` and ``stop()`` have been called.
    """
    import time

    class Timer:
        def __init__(self):
            self.start_time = None
            self.end_time = None

        def start(self):
            # perf_counter is monotonic; time.time() is wall-clock and can
            # jump backwards (e.g. NTP), producing negative elapsed values.
            self.start_time = time.perf_counter()

        def stop(self):
            self.end_time = time.perf_counter()
            return self.elapsed

        @property
        def elapsed(self):
            # Compare against None explicitly: a 0.0 timestamp is falsy
            # yet still a valid reading.
            if self.start_time is not None and self.end_time is not None:
                return self.end_time - self.start_time
            return None

    return Timer()

tests/test_api_endpoints.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import pytest
2+
import json
3+
import asyncio
4+
from fastapi import status
5+
from unittest.mock import patch, MagicMock, AsyncMock
6+
7+
8+
class TestResponsesEndpoint:
    """Happy-path coverage for POST /v1/responses."""

    def test_basic_response_creation(self, api_client, sample_request_data):
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK
        body = result.json()
        assert "id" in body
        assert body["object"] == "response"
        assert body["model"] == sample_request_data["model"]

    def test_response_with_high_reasoning(self, api_client, sample_request_data):
        sample_request_data["reasoning_effort"] = "high"
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK
        body = result.json()
        assert "id" in body
        assert body["status"] == "completed"

    def test_response_with_medium_reasoning(self, api_client, sample_request_data):
        sample_request_data["reasoning_effort"] = "medium"
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK
        body = result.json()
        assert "id" in body
        assert body["status"] == "completed"

    def test_response_with_invalid_model(self, api_client, sample_request_data):
        sample_request_data["model"] = "invalid-model"
        result = api_client.post("/v1/responses", json=sample_request_data)
        # Unknown model names are currently accepted rather than rejected.
        assert result.status_code == status.HTTP_200_OK

    def test_response_with_empty_input(self, api_client, sample_request_data):
        sample_request_data["input"] = ""
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK

    def test_response_with_tools(self, api_client, sample_request_data):
        sample_request_data["tools"] = [{"type": "browser_search"}]
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK

    def test_response_with_custom_temperature(self, api_client, sample_request_data):
        # Sweep the full accepted temperature range.
        for temperature in (0.0, 0.5, 1.0, 1.5, 2.0):
            sample_request_data["temperature"] = temperature
            result = api_client.post("/v1/responses", json=sample_request_data)
            assert result.status_code == status.HTTP_200_OK
            assert "usage" in result.json()

    def test_streaming_response(self, api_client, sample_request_data):
        sample_request_data["stream"] = True
        with api_client.stream("POST", "/v1/responses", json=sample_request_data) as result:
            assert result.status_code == status.HTTP_200_OK
            # The first SSE data payload must be valid JSON (unless [DONE]).
            for line in result.iter_lines():
                if line and line.startswith("data: "):
                    payload = line[6:]  # strip the "data: " prefix
                    if payload != "[DONE]":
                        json.loads(payload)
                    break
73+
74+
75+
class TestResponsesWithSession:
    """Session and continuation behaviour of /v1/responses."""

    def test_response_with_session_id(self, api_client, sample_request_data):
        session_id = "test-session-123"
        sample_request_data["session_id"] = session_id

        # First request in the session.
        first = api_client.post("/v1/responses", json=sample_request_data)
        assert first.status_code == status.HTTP_200_OK

        # Second request reuses the same session id.
        sample_request_data["input"] = "Follow up question"
        second = api_client.post("/v1/responses", json=sample_request_data)
        assert second.status_code == status.HTTP_200_OK

        # Each response in a session still gets a unique id.
        assert first.json()["id"] != second.json()["id"]

    def test_response_continuation(self, api_client, sample_request_data):
        # Create an initial response to continue from.
        initial = api_client.post("/v1/responses", json=sample_request_data)
        assert initial.status_code == status.HTTP_200_OK

        follow_up = {
            "model": sample_request_data["model"],
            "response_id": initial.json()["id"],
            "input": "Continue the previous thought",
        }
        continuation = api_client.post("/v1/responses", json=follow_up)
        assert continuation.status_code == status.HTTP_200_OK
110+
111+
112+
class TestErrorHandling:
    """Validation and malformed-request handling."""

    def test_missing_required_fields(self, api_client):
        # "model" has a default, so an empty body exercises the other
        # required fields.
        result = api_client.post("/v1/responses", json={})
        assert result.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_invalid_reasoning_effort(self, api_client, sample_request_data):
        sample_request_data["reasoning_effort"] = "invalid"
        result = api_client.post("/v1/responses", json=sample_request_data)
        # The server may coerce the value or reject the request outright.
        acceptable = [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY]
        assert result.status_code in acceptable

    def test_malformed_json(self, api_client):
        result = api_client.post(
            "/v1/responses",
            data="not json",
            headers={"Content-Type": "application/json"},
        )
        assert result.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_extremely_long_input(self, api_client, sample_request_data):
        # A very long input must still be processed successfully.
        sample_request_data["input"] = "x" * 100000
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK
138+
139+
140+
class TestToolIntegration:
    """Requests that declare built-in and function tools."""

    def test_browser_search_tool(self, api_client, sample_request_data):
        sample_request_data["tools"] = [{"type": "browser_search"}]
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK

    def test_function_tool_integration(self, api_client, sample_request_data):
        sample_request_data["tools"] = [
            {
                "type": "function",
                "name": "test_function",
                "parameters": {"type": "object", "properties": {}},
                "description": "Test function",
            },
        ]
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK

    def test_multiple_tools(self, api_client, sample_request_data):
        # Built-in and function tools declared together.
        sample_request_data["tools"] = [
            {"type": "browser_search"},
            {
                "type": "function",
                "name": "test_function",
                "parameters": {"type": "object", "properties": {}},
                "description": "Test function",
            },
        ]
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK
177+
178+
179+
class TestPerformance:
    """Coarse latency and repeated-request sanity checks."""

    def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
        performance_timer.start()
        result = api_client.post("/v1/responses", json=sample_request_data)
        elapsed = performance_timer.stop()

        assert result.status_code == status.HTTP_200_OK
        # Mock inference should answer well inside the 5 s threshold.
        assert elapsed < 5.0

    def test_multiple_sequential_requests(self, api_client, sample_request_data):
        # Back-to-back requests must all succeed independently.
        for index in range(3):
            payload = sample_request_data.copy()
            payload["input"] = f"Request {index}"
            result = api_client.post("/v1/responses", json=payload)
            assert result.status_code == status.HTTP_200_OK
197+
198+
199+
class TestUsageTracking:
    """Shape and arithmetic of the usage object in responses."""

    def test_usage_object_structure(self, api_client, sample_request_data):
        result = api_client.post("/v1/responses", json=sample_request_data)
        assert result.status_code == status.HTTP_200_OK
        body = result.json()

        assert "usage" in body
        usage = body["usage"]
        for field in ("input_tokens", "output_tokens", "total_tokens"):
            assert field in usage
        # reasoning_tokens may not always be present
        # assert "reasoning_tokens" in usage

        # Counts are non-negative and the total is consistent.
        assert usage["input_tokens"] >= 0
        assert usage["output_tokens"] >= 0
        assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]

    def test_usage_increases_with_longer_input(self, api_client, sample_request_data):
        # Usage for the short baseline input.
        short_usage = api_client.post(
            "/v1/responses", json=sample_request_data
        ).json()["usage"]

        # Usage for a ten-times-longer input.
        sample_request_data["input"] = sample_request_data["input"] * 10
        long_usage = api_client.post(
            "/v1/responses", json=sample_request_data
        ).json()["usage"]

        # The longer prompt must consume more input tokens.
        assert long_usage["input_tokens"] > short_usage["input_tokens"]

0 commit comments

Comments
 (0)