diff --git a/gpt_oss/metal/python/module.c b/gpt_oss/metal/python/module.c
index 2910c8f..4da6b2d 100644
--- a/gpt_oss/metal/python/module.c
+++ b/gpt_oss/metal/python/module.c
@@ -22,46 +22,54 @@ PyMODINIT_FUNC PyInit__metal(void) {
     PyObject* context_type = NULL;
 
     if (PyType_Ready(&PyGPTOSSModel_Type) < 0) {
-        goto error;
+        goto cleanup;
     }
     model_type = (PyObject*) &PyGPTOSSModel_Type;
     Py_INCREF(model_type);
 
     if (PyType_Ready(&PyGPTOSSTokenizer_Type) < 0) {
-        goto error;
+        goto cleanup;
     }
     tokenizer_type = (PyObject*) &PyGPTOSSTokenizer_Type;
     Py_INCREF(tokenizer_type);
 
     if (PyType_Ready(&PyGPTOSSContext_Type) < 0) {
-        goto error;
+        goto cleanup;
    }
     context_type = (PyObject*) &PyGPTOSSContext_Type;
     Py_INCREF(context_type);
 
     module = PyModule_Create(&metal_module);
     if (module == NULL) {
-        goto error;
+        goto cleanup;
     }
 
-    if (PyModule_AddObject(module, "Model", model_type) < 0) {
-        goto error;
+    // Use PyModule_AddObjectRef to handle reference counting correctly
+    if (PyModule_AddObjectRef(module, "Model", model_type) < 0) {
+        goto cleanup;
     }
 
-    if (PyModule_AddObject(module, "Tokenizer", tokenizer_type) < 0) {
-        goto error;
+    if (PyModule_AddObjectRef(module, "Tokenizer", tokenizer_type) < 0) {
+        goto cleanup;
     }
 
-    if (PyModule_AddObject(module, "Context", context_type) < 0) {
-        goto error;
+    if (PyModule_AddObjectRef(module, "Context", context_type) < 0) {
+        goto cleanup;
     }
 
-    return module;
+    // All objects registered; clear the locals so the shared cleanup below skips them
+    // (PyModule_AddObjectRef does not steal a reference, the module holds its own)
+    model_type = NULL;
+    tokenizer_type = NULL;
+    context_type = NULL;
 
-error:
+cleanup:
+    // Clean up type object references - does nothing if pointers are NULL
     Py_XDECREF(context_type);
     Py_XDECREF(tokenizer_type);
     Py_XDECREF(model_type);
-    Py_XDECREF(module);
-    return NULL;
+    if (PyErr_Occurred()) {
+        // A step failed: release the module (if it was created) so NULL is returned with the error set
+        Py_CLEAR(module);
+    }
+    return module;
 }
diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py
index 7fd354b..c88584a 100644
--- a/tests/test_api_endpoints.py
+++ b/tests/test_api_endpoints.py
@@ -8,6 +8,7 @@
 class TestResponsesEndpoint:
 
     def test_basic_response_creation(self, api_client, sample_request_data):
+        """Test that a basic response can be created successfully."""
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
         data = response.json()
@@ -16,6 +17,7 @@ def test_basic_response_creation(self, api_client, sample_request_data):
         assert data["model"] == sample_request_data["model"]
 
     def test_response_with_high_reasoning(self, api_client, sample_request_data):
+        """Test response creation with high reasoning effort."""
         sample_request_data["reasoning_effort"] = "high"
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
@@ -24,6 +26,7 @@ def test_response_with_high_reasoning(self, api_client, sample_request_data):
         assert data["status"] == "completed"
 
     def test_response_with_medium_reasoning(self, api_client, sample_request_data):
+        """Test response creation with medium reasoning effort."""
         sample_request_data["reasoning_effort"] = "medium"
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
@@ -32,17 +35,20 @@ def test_response_with_medium_reasoning(self, api_client, sample_request_data):
         assert data["status"] == "completed"
 
     def test_response_with_invalid_model(self, api_client, sample_request_data):
+        """Test response when an invalid model is specified."""
         sample_request_data["model"] = "invalid-model"
         response = api_client.post("/v1/responses", json=sample_request_data)
         # Should still accept but might handle differently
         assert response.status_code == status.HTTP_200_OK
 
     def test_response_with_empty_input(self, api_client, sample_request_data):
+        """Test response when input is an empty string."""
         sample_request_data["input"] = ""
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
 
     def test_response_with_tools(self, api_client, sample_request_data):
+        """Test response when tools are specified in the request."""
         sample_request_data["tools"] = [
             {
                 "type": "browser_search"
@@ -52,6 +58,7 @@ def test_response_with_tools(self, api_client, sample_request_data):
         assert response.status_code == status.HTTP_200_OK
 
     def test_response_with_custom_temperature(self, api_client, sample_request_data):
+        """Test response with various temperature values."""
         for temp in [0.0, 0.5, 1.0, 1.5, 2.0]:
             sample_request_data["temperature"] = temp
             response = api_client.post("/v1/responses", json=sample_request_data)
@@ -60,6 +67,7 @@ def test_response_with_custom_temperature(self, api_client, sample_request_data)
             assert "usage" in data
 
     def test_streaming_response(self, api_client, sample_request_data):
+        """Test streaming response functionality using SSE."""
         sample_request_data["stream"] = True
         with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response:
             assert response.status_code == status.HTTP_200_OK
@@ -75,6 +83,7 @@
 
 class TestResponsesWithSession:
 
     def test_response_with_session_id(self, api_client, sample_request_data):
+        """Test that responses with the same session_id are handled correctly."""
         session_id = "test-session-123"
         sample_request_data["session_id"] = session_id
@@ -93,6 +102,7 @@ def test_response_with_session_id(self, api_client, sample_request_data):
         assert data1["id"] != data2["id"]
 
     def test_response_continuation(self, api_client, sample_request_data):
+        """Test continuation of a previous response using response_id."""
         # Create initial response
         response1 = api_client.post("/v1/responses", json=sample_request_data)
         assert response1.status_code == status.HTTP_200_OK
@@ -110,19 +120,38 @@
 
 
 class TestErrorHandling:
+    def test_invalid_types_in_payload(self, api_client, sample_request_data):
+        """Test that fields with invalid types are rejected with a validation error."""
+        # Set input to an integer instead of a string
+        sample_request_data["input"] = 12345
+        response = api_client.post("/v1/responses", json=sample_request_data)
+        # Should return 422 Unprocessable Entity or 400 Bad Request
+        assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_400_BAD_REQUEST]
+
+        # Set model to a list instead of a string
+        sample_request_data["model"] = ["gpt-oss-20b"]
+        response = api_client.post("/v1/responses", json=sample_request_data)
+        assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_400_BAD_REQUEST]
+
+        # Set tools to a string instead of a list
+        sample_request_data["tools"] = "browser_search"
+        response = api_client.post("/v1/responses", json=sample_request_data)
+        assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_400_BAD_REQUEST]
 
     def test_missing_required_fields(self, api_client):
+        """Test response when required fields are missing from the request."""
         # Model field has default, so test with empty JSON
         response = api_client.post("/v1/responses", json={})
         assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
 
     def test_invalid_reasoning_effort(self, api_client, sample_request_data):
+        """Test response when reasoning_effort has an invalid value."""
         sample_request_data["reasoning_effort"] = "invalid"
         response = api_client.post("/v1/responses", json=sample_request_data)
         # May handle gracefully or return error
         assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY]
 
     def test_malformed_json(self, api_client):
+        """Test response when the request body is not valid JSON."""
         response = api_client.post(
             "/v1/responses",
             data="not json",
@@ -131,6 +160,7 @@ def test_malformed_json(self, api_client):
         assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
 
     def test_extremely_long_input(self, api_client, sample_request_data):
+        """Test response when input is extremely long."""
         # Test with very long input
         sample_request_data["input"] = "x" * 100000
         response = api_client.post("/v1/responses", json=sample_request_data)
@@ -140,6 +170,7 @@
 class TestToolIntegration:
 
     def test_browser_search_tool(self, api_client, sample_request_data):
+        """Test response when browser_search tool is specified."""
         sample_request_data["tools"] = [
             {
                 "type": "browser_search"
@@ -149,6 +180,7 @@ def test_browser_search_tool(self, api_client, sample_request_data):
         assert response.status_code == status.HTTP_200_OK
 
     def test_function_tool_integration(self, api_client, sample_request_data):
+        """Test response when a function tool is specified."""
         sample_request_data["tools"] = [
             {
                 "type": "function",
@@ -161,6 +193,7 @@ def test_function_tool_integration(self, api_client, sample_request_data):
         assert response.status_code == status.HTTP_200_OK
 
     def test_multiple_tools(self, api_client, sample_request_data):
+        """Test response when multiple tools are specified."""
         sample_request_data["tools"] = [
             {
                 "type": "browser_search"
@@ -179,6 +212,7 @@
 class TestPerformance:
 
     def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
+        """Test that response time is under a specified threshold."""
         performance_timer.start()
         response = api_client.post("/v1/responses", json=sample_request_data)
         elapsed = performance_timer.stop()
@@ -188,6 +222,7 @@ def test_response_time_under_threshold(self, api_client, sample_request_data, pe
         assert elapsed < 5.0  # 5 seconds threshold
 
     def test_multiple_sequential_requests(self, api_client, sample_request_data):
+        """Test that multiple sequential requests are handled correctly."""
         # Test multiple requests work correctly
         for i in range(3):
             data = sample_request_data.copy()
@@ -199,6 +234,7 @@
 class TestUsageTracking:
 
     def test_usage_object_structure(self, api_client, sample_request_data):
+        """Test that the usage object in the response has the correct structure and values."""
         response = api_client.post("/v1/responses", json=sample_request_data)
         assert response.status_code == status.HTTP_200_OK
         data = response.json()
@@ -217,6 +253,7 @@ def test_usage_object_structure(self, api_client, sample_request_data):
         assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
 
     def test_usage_increases_with_longer_input(self, api_client, sample_request_data):
+        """Test that input_tokens increases as the input length increases."""
         # Short input
         response1 = api_client.post("/v1/responses", json=sample_request_data)
         usage1 = response1.json()["usage"]
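
For reference, the ownership pattern behind the PyModule_AddObjectRef change above, shown as a minimal, self-contained sketch. This is not part of the patch: the module name (_demo), the type (PyDemo_Type), and the init function are illustrative only. Unlike PyModule_AddObject, PyModule_AddObjectRef (Python 3.10+) does not steal a reference on success, and the init function must still release the module and return NULL if a later registration step fails.

/* Minimal sketch of a module init using PyModule_AddObjectRef.
 * All names here are illustrative, not taken from the patch. */
#define PY_SSIZE_T_CLEAN
#include <Python.h>

static PyTypeObject PyDemo_Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "_demo.Demo",
    .tp_basicsize = sizeof(PyObject),
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_new = PyType_GenericNew,
};

static struct PyModuleDef demo_module = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_demo",
    .m_size = -1,
};

PyMODINIT_FUNC PyInit__demo(void) {
    if (PyType_Ready(&PyDemo_Type) < 0) {
        return NULL;
    }

    PyObject* module = PyModule_Create(&demo_module);
    if (module == NULL) {
        return NULL;
    }

    /* PyModule_AddObjectRef takes its own strong reference and leaves the
     * caller's reference untouched, so no INCREF/DECREF bookkeeping is
     * needed around the call. */
    if (PyModule_AddObjectRef(module, "Demo", (PyObject*) &PyDemo_Type) < 0) {
        /* On failure, release the half-initialized module and return NULL
         * with the exception set; returning a module object while an error
         * is pending makes the import machinery raise a SystemError. */
        Py_DECREF(module);
        return NULL;
    }

    return module;
}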