From 71f22fceaaf3932a02630ae332fbf98100c6db26 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 26 May 2025 12:08:25 -0600 Subject: [PATCH 1/3] Add MCP Changes --- databricks/sdk/__init__.py | 7 ++++++- databricks/sdk/service/mcp.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 databricks/sdk/service/mcp.py diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index f75645d25..6345a89e4 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -132,7 +132,7 @@ VectorSearchIndexesAPI) from databricks.sdk.service.workspace import (GitCredentialsAPI, ReposAPI, SecretsAPI, WorkspaceAPI) - +from databricks.sdk.service.mcp import MCP _LOG = logging.getLogger(__name__) @@ -337,6 +337,7 @@ def __init__( self._workspace_bindings = pkg_catalog.WorkspaceBindingsAPI(self._api_client) self._workspace_conf = pkg_settings.WorkspaceConfAPI(self._api_client) self._forecasting = pkg_ml.ForecastingAPI(self._api_client) + self._mcp = MCP(self.config) @property def config(self) -> client.Config: @@ -860,6 +861,10 @@ def forecasting(self) -> pkg_ml.ForecastingAPI: """The Forecasting API allows you to create and get serverless forecasting experiments.""" return self._forecasting + @property + def mcp(self) -> MCP: + return self._mcp + def get_workspace_id(self) -> int: """Get the workspace ID of the workspace that this client is connected to.""" response = self._api_client.do("GET", "/api/2.0/preview/scim/v2/Me", response_headers=["X-Databricks-Org-Id"]) diff --git a/databricks/sdk/service/mcp.py b/databricks/sdk/service/mcp.py new file mode 100644 index 000000000..4eb71a46a --- /dev/null +++ b/databricks/sdk/service/mcp.py @@ -0,0 +1,25 @@ +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared.auth import OAuthToken + +class DatabricksTokenStorage(TokenStorage): + def __init__(self, config): + self.config = config + + async def get_tokens(self) -> OAuthToken| None: + headers = self.config.authenticate() + token = headers["Authorization"].split("Bearer ")[1] + return OAuthToken(access_token=token, expires_in=60) + +class MCP: + def __init__(self, config): + self._config = config + self.databricks_token_storage = DatabricksTokenStorage(config) + + def oauth_provider(self): + return OAuthClientProvider( + server_url="", + client_metadata=None, + storage=self.databricks_token_storage, + redirect_handler=None, + callback_handler=None, + ) \ No newline at end of file From c894e3efb8d2495931f2d8e254560f528fceb517 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 26 May 2025 12:43:22 -0600 Subject: [PATCH 2/3] Update pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 60c33f0e6..3d14393bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,9 @@ openai = [ 'langchain-openai; python_version > "3.7"', "httpx", ] +mcp = [ + "mcp>=1.9.1" +] [tool.setuptools.dynamic] version = { attr = "databricks.sdk.version.__version__" } From fd36b9f892a7dee99a1e212014da8149d70c5519 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Wed, 28 May 2025 10:42:57 -0600 Subject: [PATCH 3/3] Add tests --- databricks/sdk/__init__.py | 5 +++-- databricks/sdk/service/mcp.py | 10 ++++++---- pyproject.toml | 2 ++ tests/test_mcp.py | 20 ++++++++++++++++++++ 4 files changed, 31 insertions(+), 6 deletions(-) create mode 100644 tests/test_mcp.py diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 6345a89e4..e1836b214 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -84,6 +84,7 @@ ProviderExchangeFiltersAPI, ProviderExchangesAPI, ProviderFilesAPI, ProviderListingsAPI, ProviderPersonalizationRequestsAPI, ProviderProviderAnalyticsDashboardsAPI, ProviderProvidersAPI) +from databricks.sdk.service.mcp import MCP from databricks.sdk.service.ml import (ExperimentsAPI, ForecastingAPI, ModelRegistryAPI) from databricks.sdk.service.oauth2 import (AccountFederationPolicyAPI, @@ -132,7 +133,7 @@ VectorSearchIndexesAPI) from databricks.sdk.service.workspace import (GitCredentialsAPI, ReposAPI, SecretsAPI, WorkspaceAPI) -from databricks.sdk.service.mcp import MCP + _LOG = logging.getLogger(__name__) @@ -861,7 +862,7 @@ def forecasting(self) -> pkg_ml.ForecastingAPI: """The Forecasting API allows you to create and get serverless forecasting experiments.""" return self._forecasting - @property + @property def mcp(self) -> MCP: return self._mcp diff --git a/databricks/sdk/service/mcp.py b/databricks/sdk/service/mcp.py index 4eb71a46a..d906e3487 100644 --- a/databricks/sdk/service/mcp.py +++ b/databricks/sdk/service/mcp.py @@ -1,25 +1,27 @@ from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.shared.auth import OAuthToken + class DatabricksTokenStorage(TokenStorage): def __init__(self, config): self.config = config - async def get_tokens(self) -> OAuthToken| None: + async def get_tokens(self) -> OAuthToken | None: headers = self.config.authenticate() token = headers["Authorization"].split("Bearer ")[1] return OAuthToken(access_token=token, expires_in=60) - + + class MCP: def __init__(self, config): self._config = config self.databricks_token_storage = DatabricksTokenStorage(config) - def oauth_provider(self): + def get_oauth_provider(self): return OAuthClientProvider( server_url="", client_metadata=None, storage=self.databricks_token_storage, redirect_handler=None, callback_handler=None, - ) \ No newline at end of file + ) diff --git a/pyproject.toml b/pyproject.toml index 3d14393bb..5e90e6c7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ dev = [ 'langchain-openai; python_version > "3.7"', "httpx", "build", # some integration tests depend on the databricks-sdk-py wheel + "mcp>=1.9.1", + "pytest-asyncio" ] notebook = [ "ipython>=8,<10", diff --git a/tests/test_mcp.py b/tests/test_mcp.py new file mode 100644 index 000000000..9178ef2a6 --- /dev/null +++ b/tests/test_mcp.py @@ -0,0 +1,20 @@ +import time + +import httpx +import pytest + + +@pytest.mark.asyncio +async def test_mcp_oauth_provider(monkeypatch): + monkeypatch.setattr(time, "time", lambda: 100) + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv("DATABRICKS_HOST", "test_host") + monkeypatch.setenv("DATABRICKS_TOKEN", "test_token") + + w = WorkspaceClient() + mcp_oauth_provider = w.mcp.get_oauth_provider() + + request = httpx.Request("GET", "https://example.com") + response = await anext(mcp_oauth_provider.async_auth_flow(request)) + assert response.headers["Authorization"] == "Bearer test_token"