Skip to content

Commit 89fafd4

Browse files
Unit tests
1 parent a206173 commit 89fafd4

File tree

3 files changed

+118
-3
lines changed

3 files changed

+118
-3
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,12 @@ def env_oidc(cfg) -> Optional[CredentialsProvider]:
321321

322322
return _oidc_credentials_provider(cfg, oidc.EnvIdTokenSource(env_var))
323323

324+
324325
@credentials_strategy("file-oidc", ["host", "oidc_token_filepath"])
325326
def file_oidc(cfg) -> Optional[CredentialsProvider]:
326327
return _oidc_credentials_provider(cfg, oidc.FileIdTokenSource(cfg.oidc_token_filepath))
327328

329+
328330
# This function is a helper function to create an OIDC CredentialsProvider
329331
# that provides a Databricks token from an IdTokenSource.
330332
def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]:

tests/test_config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77

88
import pytest
99

10-
from databricks.sdk import useragent
10+
from databricks.sdk import oauth, useragent
1111
from databricks.sdk.config import Config, with_product, with_user_agent_extra
12-
from databricks.sdk.credentials_provider import Token
1312
from databricks.sdk.version import __version__
1413

1514
from .conftest import noop_credentials, set_az_path
@@ -114,7 +113,7 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
114113
def test_config_deep_copy(monkeypatch, mocker, tmp_path):
115114
mocker.patch(
116115
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
117-
return_value=Token(
116+
return_value=oauth.Token(
118117
access_token="token",
119118
token_type="Bearer",
120119
expiry=datetime(2023, 5, 22, 0, 0, 0),

tests/test_oidc.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from dataclasses import dataclass
2+
from typing import Optional, Tuple
3+
4+
import pytest
5+
6+
from databricks.sdk import oidc
7+
8+
9+
class MockIdTokenSource(oidc.IdTokenSource):
10+
def __init__(self, id_token: str, exception: Exception = None):
11+
self.id_token = id_token
12+
self.exception = exception
13+
14+
def id_token(self) -> oidc.IdToken:
15+
if self.exception:
16+
raise self.exception
17+
return oidc.IdToken(jwt=self.id_token)
18+
19+
20+
@dataclass
21+
class EnvTestCase:
22+
name: str
23+
env_name: str = ""
24+
env_value: str = ""
25+
want: oidc.IdToken = None
26+
wantException: Exception = None
27+
28+
29+
_env_id_test_cases = [
30+
EnvTestCase(
31+
name="success",
32+
env_name="OIDC_TEST_TOKEN_SUCCESS",
33+
env_value="test-token-123",
34+
want=oidc.IdToken(jwt="test-token-123"),
35+
),
36+
EnvTestCase(
37+
name="missing_env_var",
38+
env_name="OIDC_TEST_TOKEN_MISSING",
39+
env_value="",
40+
wantException=ValueError,
41+
),
42+
EnvTestCase(
43+
name="empty_env_var",
44+
env_name="OIDC_TEST_TOKEN_EMPTY",
45+
env_value="",
46+
wantException=ValueError,
47+
),
48+
EnvTestCase(
49+
name="different_variable_name",
50+
env_name="ANOTHER_OIDC_TOKEN",
51+
env_value="another-token-456",
52+
want=oidc.IdToken(jwt="another-token-456"),
53+
),
54+
]
55+
56+
57+
@pytest.mark.parametrize("test_case", _env_id_test_cases)
58+
def test_env_id_token_source(test_case: EnvIdTestCase, monkeypatch):
59+
monkeypatch.setenv(test_case.env_name, test_case.env_value)
60+
61+
source = oidc.EnvIdTokenSource(test_case.env_name)
62+
if test_case.wantException:
63+
with pytest.raises(test_case.wantException):
64+
source.id_token()
65+
else:
66+
assert source.id_token() == test_case.want
67+
68+
69+
@dataclass
70+
class FileTestCase:
71+
name: str
72+
file: Optional[Tuple[str, str]] = None # (name, content)
73+
filepath: str = ""
74+
want: oidc.IdToken = None
75+
wantException: Exception = None
76+
77+
78+
_file_id_test_cases = [
79+
FileTestCase(
80+
name="missing_filepath",
81+
file=("token", "content"),
82+
filepath="",
83+
wantException=ValueError,
84+
),
85+
FileTestCase(
86+
name="empty_file",
87+
file=("token", ""),
88+
filepath="token",
89+
wantException=ValueError,
90+
),
91+
FileTestCase(
92+
name="file_does_not_exist",
93+
),
94+
FileTestCase(
95+
name="file_exists",
96+
file=("token", "content"),
97+
filepath="token",
98+
want=oidc.IdToken(jwt="content"),
99+
),
100+
]
101+
102+
103+
@pytest.mark.parametrize("test_case", _file_id_test_cases)
104+
def test_file_id_token_source(test_case: FileTestCase, tmp_path):
105+
if test_case.file:
106+
token_file = tmp_path / test_case.file[0]
107+
token_file.write_text(test_case.file[1])
108+
109+
source = oidc.FileIdTokenSource(test_case.filepath)
110+
if test_case.wantException:
111+
with pytest.raises(test_case.wantException):
112+
source.id_token()
113+
else:
114+
assert source.id_token() == test_case.want

0 commit comments

Comments
 (0)