Skip to content

Commit c801502

Browse files
committed
Fix tests
1 parent ce0475b commit c801502

File tree

1 file changed

+39
-8
lines changed

1 file changed

+39
-8
lines changed

tests/test_notebook_oauth.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
"""Tests for runtime OAuth authentication in notebook environments."""
22

33
import os
4+
import sys
5+
import types
46
from datetime import datetime, timedelta
5-
from unittest.mock import patch
7+
from typing import Dict
68

79
import pytest
810

911
from databricks.sdk import oauth
1012
from databricks.sdk.config import Config
11-
from databricks.sdk.credentials_provider import (DefaultCredentials,
13+
from databricks.sdk.credentials_provider import (CredentialsProvider,
14+
CredentialsStrategy,
15+
DefaultCredentials,
1216
runtime_oauth)
1317

1418

@@ -24,14 +28,26 @@ def mock_runtime_env(monkeypatch):
2428
@pytest.fixture
2529
def mock_runtime_native_auth():
2630
"""Mock the runtime_native_auth to return a valid credentials provider."""
27-
with patch("databricks.sdk.runtime.init_runtime_native_auth") as mock_auth:
31+
fake_runtime = types.ModuleType("databricks.sdk.runtime")
2832

33+
def fake_init_runtime_native_auth():
2934
def inner():
3035
return {"Authorization": "Bearer test-notebook-pat-token"}
3136

32-
mock_auth.return_value = ("https://test.cloud.databricks.com", inner)
33-
mock_auth.__name__ = "init_runtime_native_auth"
34-
yield mock_auth
37+
return "https://test.cloud.databricks.com", inner
38+
39+
def fake_init_runtime_legacy_auth():
40+
pass
41+
42+
def fake_init_runtime_repl_auth():
43+
pass
44+
45+
fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
46+
fake_runtime.init_runtime_legacy_auth = fake_init_runtime_legacy_auth
47+
fake_runtime.init_runtime_repl_auth = fake_init_runtime_repl_auth
48+
49+
sys.modules["databricks.sdk.runtime"] = fake_runtime
50+
yield
3551

3652

3753
@pytest.fixture
@@ -48,6 +64,17 @@ def mock_pat_exchange(mocker):
4864
return mock_exchange
4965

5066

67+
class MockCredentialsStrategy(CredentialsStrategy):
68+
def auth_type(self) -> str:
69+
return "mock_credentials_strategy"
70+
71+
def __call__(self, cfg) -> CredentialsProvider:
72+
def credentials_provider() -> Dict[str, str]:
73+
return {"Authorization": "Bearer: no_token"}
74+
75+
return credentials_provider
76+
77+
5178
@pytest.mark.parametrize(
5279
"scopes,auth_details",
5380
[
@@ -61,8 +88,12 @@ def test_runtime_oauth_success_scenarios(
6188
mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes, auth_details
6289
):
6390
"""Test runtime-oauth works correctly in various valid configurations."""
64-
cfg = Config(host="https://test.cloud.databricks.com", scopes=scopes, authorization_details=auth_details)
65-
91+
cfg = Config(
92+
host="https://test.cloud.databricks.com",
93+
scopes=scopes,
94+
authorization_details=auth_details,
95+
credentials_strategy=MockCredentialsStrategy(),
96+
)
6697
creds_provider = runtime_oauth(cfg)
6798

6899
assert creds_provider is not None

0 commit comments

Comments
 (0)