Fix memory leak in Python extension module initialization #91

Open · wants to merge 3 commits into main
32 changes: 18 additions & 14 deletions gpt_oss/metal/python/module.c
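The change in brief: the three PyModule_AddObject calls become PyModule_AddObjectRef, the error: label becomes cleanup:, and the success path now falls through cleanup: after NULL-ing the type pointers instead of returning early.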
@@ -22,46 +22,50 @@ PyMODINIT_FUNC PyInit__metal(void) {
     PyObject* context_type = NULL;
 
     if (PyType_Ready(&PyGPTOSSModel_Type) < 0) {
-        goto error;
+        goto cleanup;
     }
     model_type = (PyObject*) &PyGPTOSSModel_Type;
     Py_INCREF(model_type);
 
     if (PyType_Ready(&PyGPTOSSTokenizer_Type) < 0) {
-        goto error;
+        goto cleanup;
     }
     tokenizer_type = (PyObject*) &PyGPTOSSTokenizer_Type;
     Py_INCREF(tokenizer_type);
 
     if (PyType_Ready(&PyGPTOSSContext_Type) < 0) {
-        goto error;
+        goto cleanup;
     }
     context_type = (PyObject*) &PyGPTOSSContext_Type;
     Py_INCREF(context_type);
 
     module = PyModule_Create(&metal_module);
     if (module == NULL) {
-        goto error;
+        goto cleanup;
     }
 
-    if (PyModule_AddObject(module, "Model", model_type) < 0) {
-        goto error;
+    // Use PyModule_AddObjectRef to handle reference counting correctly
+    if (PyModule_AddObjectRef(module, "Model", model_type) < 0) {
+        goto cleanup;
     }
 
-    if (PyModule_AddObject(module, "Tokenizer", tokenizer_type) < 0) {
-        goto error;
+    if (PyModule_AddObjectRef(module, "Tokenizer", tokenizer_type) < 0) {
+        goto cleanup;
     }
 
-    if (PyModule_AddObject(module, "Context", context_type) < 0) {
-        goto error;
+    if (PyModule_AddObjectRef(module, "Context", context_type) < 0) {
+        goto cleanup;
     }
 
-    return module;
+    // Successfully added all objects to module, set pointers to NULL to avoid double-decref in cleanup
+    model_type = NULL;
+    tokenizer_type = NULL;
+    context_type = NULL;
Collaborator comment:
Duplicating deallocation code is undesirable. Instead, rename the error: label to cleanup: and make it execute on both success and failure, similarly to:

    cleanup:
        if (fd != -1) {
            close(fd);
            fd = -1;
        }
        gptoss_model_release(model);          // does nothing if model is NULL
        gptoss_tokenizer_release(tokenizer);  // does nothing if tokenizer is NULL
        return status;
    }
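The diff follows that suggestion: the success path NULLs the three type pointers and falls through into cleanup:, so the Py_XDECREF calls become no-ops and the live module is returned.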

-error:
+cleanup:
+    // Clean up type object references - does nothing if pointers are NULL
     Py_XDECREF(context_type);
     Py_XDECREF(tokenizer_type);
     Py_XDECREF(model_type);
-    Py_XDECREF(module);
-    return NULL;
+    return module;
 }
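Background on the two changes, not part of the diff itself: PyModule_AddObject steals a reference to its value argument on success but not on failure, an asymmetry that makes leaks like this easy to write; PyModule_AddObjectRef, available since Python 3.10, never steals the reference, so the Py_INCREF/Py_XDECREF pairing above stays uniform on every path. One edge case the diff appears to leave open: if an AddObjectRef call fails after PyModule_Create has succeeded, cleanup: returns the non-NULL module with an exception set. A minimal sketch of a stricter epilogue, assuming the surrounding code shown above:

    cleanup:
        // No-ops when the success path has already NULLed these pointers.
        Py_XDECREF(context_type);
        Py_XDECREF(tokenizer_type);
        Py_XDECREF(model_type);
        if (PyErr_Occurred()) {
            // A failure after PyModule_Create must not return a live module.
            Py_CLEAR(module);  // Py_XDECREF(module), then set it to NULL
        }
        return module;
    }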
37 changes: 37 additions & 0 deletions tests/test_api_endpoints.py
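The change in brief: every test gains a one-line docstring, and TestErrorHandling gains a new test_invalid_types_in_payload case covering wrongly typed input, model, and tools fields.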
@@ -8,6 +8,7 @@
 class TestResponsesEndpoint:
 
     def test_basic_response_creation(self, api_client, sample_request_data):
+        """Test that a basic response can be created successfully."""
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
         data = response.json()
@@ -16,6 +17,7 @@ def test_basic_response_creation(self, api_client, sample_request_data):
         assert data["model"] == sample_request_data["model"]
 
     def test_response_with_high_reasoning(self, api_client, sample_request_data):
+        """Test response creation with high reasoning effort."""
         sample_request_data["reasoning_effort"] = "high"
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
@@ -24,6 +26,7 @@ def test_response_with_high_reasoning(self, api_client, sample_request_data):
         assert data["status"] == "completed"
 
     def test_response_with_medium_reasoning(self, api_client, sample_request_data):
+        """Test response creation with medium reasoning effort."""
         sample_request_data["reasoning_effort"] = "medium"
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
@@ -32,17 +35,20 @@ def test_response_with_medium_reasoning(self, api_client, sample_request_data):
         assert data["status"] == "completed"
 
     def test_response_with_invalid_model(self, api_client, sample_request_data):
+        """Test response when an invalid model is specified."""
         sample_request_data["model"] = "invalid-model"
         response = api_client.post("/v1/responses", json=sample_request_data)
         # Should still accept but might handle differently
         assert response.status_code == status.HTTP_200_OK
 
     def test_response_with_empty_input(self, api_client, sample_request_data):
+        """Test response when input is an empty string."""
         sample_request_data["input"] = ""
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
 
     def test_response_with_tools(self, api_client, sample_request_data):
+        """Test response when tools are specified in the request."""
         sample_request_data["tools"] = [
             {
                 "type": "browser_search"
@@ -52,6 +58,7 @@ def test_response_with_tools(self, api_client, sample_request_data):
         assert response.status_code == status.HTTP_200_OK
 
     def test_response_with_custom_temperature(self, api_client, sample_request_data):
+        """Test response with various temperature values."""
         for temp in [0.0, 0.5, 1.0, 1.5, 2.0]:
             sample_request_data["temperature"] = temp
             response = api_client.post("/v1/responses", json=sample_request_data)
@@ -60,6 +67,7 @@ def test_response_with_custom_temperature(self, api_client, sample_request_data):
             assert "usage" in data
 
     def test_streaming_response(self, api_client, sample_request_data):
+        """Test streaming response functionality using SSE."""
         sample_request_data["stream"] = True
         with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response:
             assert response.status_code == status.HTTP_200_OK
@@ -75,6 +83,7 @@ def test_streaming_response(self, api_client, sample_request_data):
 class TestResponsesWithSession:
 
     def test_response_with_session_id(self, api_client, sample_request_data):
+        """Test that responses with the same session_id are handled correctly."""
         session_id = "test-session-123"
         sample_request_data["session_id"] = session_id
 
@@ -93,6 +102,7 @@ def test_response_with_session_id(self, api_client, sample_request_data):
         assert data1["id"] != data2["id"]
 
     def test_response_continuation(self, api_client, sample_request_data):
+        """Test continuation of a previous response using response_id."""
         # Create initial response
         response1 = api_client.post("/v1/responses", json=sample_request_data)
         assert response1.status_code == status.HTTP_200_OK
@@ -110,19 +120,38 @@ def test_response_continuation(self, api_client, sample_request_data):
 
 
 class TestErrorHandling:
+    def test_invalid_types_in_payload(self, api_client, sample_request_data):
+        # Set input to an integer instead of a string
+        sample_request_data["input"] = 12345
+        response = api_client.post("/v1/responses", json=sample_request_data)
+        # Should return 422 Unprocessable Entity or 400 Bad Request
+        assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_400_BAD_REQUEST]
+
+        # Set model to a list instead of a string
+        sample_request_data["model"] = ["gpt-oss-20b"]
+        response = api_client.post("/v1/responses", json=sample_request_data)
+        assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_400_BAD_REQUEST]
+
+        # Set tools to a string instead of a list
+        sample_request_data["tools"] = "browser_search"
+        response = api_client.post("/v1/responses", json=sample_request_data)
+        assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_400_BAD_REQUEST]
 
     def test_missing_required_fields(self, api_client):
+        """Test response when required fields are missing from the request."""
         # Model field has default, so test with empty JSON
         response = api_client.post("/v1/responses", json={})
         assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
 
     def test_invalid_reasoning_effort(self, api_client, sample_request_data):
+        """Test response when reasoning_effort has an invalid value."""
         sample_request_data["reasoning_effort"] = "invalid"
         response = api_client.post("/v1/responses", json=sample_request_data)
         # May handle gracefully or return error
         assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY]
 
     def test_malformed_json(self, api_client):
+        """Test response when the request body is not valid JSON."""
         response = api_client.post(
             "/v1/responses",
             data="not json",
@@ -131,6 +160,7 @@ def test_malformed_json(self, api_client):
         assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
 
     def test_extremely_long_input(self, api_client, sample_request_data):
+        """Test response when input is extremely long."""
         # Test with very long input
         sample_request_data["input"] = "x" * 100000
         response = api_client.post("/v1/responses", json=sample_request_data)
@@ -140,6 +170,7 @@ def test_extremely_long_input(self, api_client, sample_request_data):
 class TestToolIntegration:
 
     def test_browser_search_tool(self, api_client, sample_request_data):
+        """Test response when browser_search tool is specified."""
         sample_request_data["tools"] = [
             {
                 "type": "browser_search"
@@ -149,6 +180,7 @@ def test_browser_search_tool(self, api_client, sample_request_data):
         assert response.status_code == status.HTTP_200_OK
 
     def test_function_tool_integration(self, api_client, sample_request_data):
+        """Test response when a function tool is specified."""
         sample_request_data["tools"] = [
             {
                 "type": "function",
@@ -161,6 +193,7 @@ def test_function_tool_integration(self, api_client, sample_request_data):
         assert response.status_code == status.HTTP_200_OK
 
     def test_multiple_tools(self, api_client, sample_request_data):
+        """Test response when multiple tools are specified."""
         sample_request_data["tools"] = [
             {
                 "type": "browser_search"
@@ -179,6 +212,7 @@ def test_multiple_tools(self, api_client, sample_request_data):
 class TestPerformance:
 
     def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
+        """Test that response time is under a specified threshold."""
         performance_timer.start()
         response = api_client.post("/v1/responses", json=sample_request_data)
         elapsed = performance_timer.stop()
@@ -188,6 +222,7 @@ def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
         assert elapsed < 5.0 # 5 seconds threshold
 
     def test_multiple_sequential_requests(self, api_client, sample_request_data):
+        """Test that multiple sequential requests are handled correctly."""
         # Test multiple requests work correctly
         for i in range(3):
             data = sample_request_data.copy()
@@ -199,6 +234,7 @@ def test_multiple_sequential_requests(self, api_client, sample_request_data):
 class TestUsageTracking:
 
     def test_usage_object_structure(self, api_client, sample_request_data):
+        """Test that the usage object in the response has the correct structure and values."""
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
         data = response.json()
@@ -217,6 +253,7 @@ def test_usage_object_structure(self, api_client, sample_request_data):
         assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
 
     def test_usage_increases_with_longer_input(self, api_client, sample_request_data):
+        """Test that input_tokens increases as the input length increases."""
         # Short input
         response1 = api_client.post("/v1/responses", json=sample_request_data)
         usage1 = response1.json()["usage"]