diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index c530be8a7..efc02e856 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,7 @@ ## Release v0.60.0 ### New Features and Improvements +* Added headers to HttpRequestResponse in OpenAI client. ### Bug Fixes diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py index 62835dcec..7d16f519d 100644 --- a/databricks/sdk/mixins/open_ai_client.py +++ b/databricks/sdk/mixins/open_ai_client.py @@ -4,6 +4,7 @@ from requests import Response from databricks.sdk.service.serving import (ExternalFunctionRequestHttpMethod, + HttpRequestResponse, ServingEndpointsAPI) @@ -88,15 +89,30 @@ def http_request( """ response = Response() response.status_code = 200 - server_response = super().http_request( - connection_name=conn, - method=method, - path=path, - headers=js.dumps(headers) if headers is not None else None, - json=js.dumps(json) if json is not None else None, - params=js.dumps(params) if params is not None else None, + + # We currently don't call super.http_request because we need to pass in response_headers + # This is a temporary fix to get the headers we need for the MCP session id + # TODO: Remove this once we have a better way to get back the response headers + headers_to_capture = ["mcp-session-id"] + res = self._api.do( + "POST", + "/api/2.0/external-function", + body={ + "connection_name": conn, + "method": method.value, + "path": path, + "headers": js.dumps(headers) if headers is not None else None, + "json": js.dumps(json) if json is not None else None, + "params": js.dumps(params) if params is not None else None, + }, + headers={"Accept": "text/plain", "Content-Type": "application/json"}, + raw=True, + response_headers=headers_to_capture, ) + # Create HttpRequestResponse from the raw response + server_response = HttpRequestResponse.from_dict(res) + # Read the content from the HttpRequestResponse object if hasattr(server_response, "contents") and hasattr(server_response.contents, "read"): raw_content = server_response.contents.read() # Read the bytes @@ -109,4 +125,9 @@ def http_request( else: raise ValueError("Contents must be bytes.") + # Copy headers from raw response to Response + for header_name in headers_to_capture: + if header_name in res: + response.headers[header_name] = res[header_name] + return response