Skip to content

Commit fd7c7ce

Browse files
committed
Add tests
1 parent d2a0175 commit fd7c7ce

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

databricks/sdk/__init__.py

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/service/mcp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
from mcp.client.auth import OAuthClientProvider, TokenStorage
22
from mcp.shared.auth import OAuthToken
33

4+
45
class DatabricksTokenStorage(TokenStorage):
56
def __init__(self, config):
67
self.config = config
78

8-
async def get_tokens(self) -> OAuthToken| None:
9+
async def get_tokens(self) -> OAuthToken | None:
910
headers = self.config.authenticate()
1011
token = headers["Authorization"].split("Bearer ")[1]
1112
return OAuthToken(access_token=token, expires_in=60)
12-
13+
14+
1315
class MCP:
1416
def __init__(self, config):
1517
self._config = config
1618
self.databricks_token_storage = DatabricksTokenStorage(config)
1719

18-
def oauth_provider(self):
20+
def get_oauth_provider(self):
1921
return OAuthClientProvider(
2022
server_url="",
2123
client_metadata=None,
2224
storage=self.databricks_token_storage,
2325
redirect_handler=None,
2426
callback_handler=None,
25-
)
27+
)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ dev = [
5353
'langchain-openai; python_version > "3.7"',
5454
"httpx",
5555
"build", # some integration tests depend on the databricks-sdk-py wheel
56+
"mcp>=1.9.1",
57+
"pytest-asyncio"
5658
]
5759
notebook = [
5860
"ipython>=8,<10",

tests/test_mcp.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import time
2+
3+
import httpx
4+
import pytest
5+
6+
7+
@pytest.mark.asyncio
8+
async def test_mcp_oauth_provider(monkeypatch):
9+
monkeypatch.setattr(time, "time", lambda: 100)
10+
from databricks.sdk import WorkspaceClient
11+
12+
monkeypatch.setenv("DATABRICKS_HOST", "test_host")
13+
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
14+
15+
w = WorkspaceClient()
16+
mcp_oauth_provider = w.mcp.get_oauth_provider()
17+
18+
request = httpx.Request("GET", "https://example.com")
19+
response = await anext(mcp_oauth_provider.async_auth_flow(request))
20+
assert response.headers["Authorization"] == "Bearer test_token"

0 commit comments

Comments
 (0)