
Commit e2d8c4c

1 parent: c391f79

9 files changed: +132 additions, −113 deletions
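
Every hunk in this commit is a formatting-only re-wrap: long calls and decorators are split so the opening parenthesis gets its own line and arguments indent one level deeper, with no behavioral change. A minimal sketch of reproducing such a sweep, assuming a yapf-style formatter is in use (the commit itself does not name its tool):

    # Sketch only: assumes yapf is installed and is the formatter behind these
    # changes; the tool choice is an assumption, not stated in the commit.
    import subprocess

    # Rewrites files in place across the SDK package and its tests.
    subprocess.run(["yapf", "--in-place", "--recursive", "databricks", "tests"], check=True)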

databricks/sdk/config.py

Lines changed: 10 additions & 9 deletions
@@ -92,15 +92,16 @@ class Config:
     max_connections_per_pool: int = ConfigAttribute()
     databricks_environment: Optional[DatabricksEnvironment] = None
 
-    def __init__(self,
-                 *,
-                 # Deprecated. Use credentials_strategy instead.
-                 credentials_provider: Optional[CredentialsStrategy] = None,
-                 credentials_strategy: Optional[CredentialsStrategy] = None,
-                 product=None,
-                 product_version=None,
-                 clock: Optional[Clock] = None,
-                 **kwargs):
+    def __init__(
+            self,
+            *,
+            # Deprecated. Use credentials_strategy instead.
+            credentials_provider: Optional[CredentialsStrategy] = None,
+            credentials_strategy: Optional[CredentialsStrategy] = None,
+            product=None,
+            product_version=None,
+            clock: Optional[Clock] = None,
+            **kwargs):
         self._header_factory = None
         self._inner = {}
         self._user_agent_other_info = []
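
The re-wrapped signature keeps every constructor argument keyword-only (note the bare * after self), so callers must pass each one by name. A small usage sketch of that call shape; the host and product values below are placeholders, not taken from this commit, and Config still resolves credentials from the environment at construction time:

    from databricks.sdk.config import Config

    # Placeholder values for illustration; real credentials (env vars, a CLI
    # profile, etc.) must be resolvable when Config is constructed.
    cfg = Config(host="https://example.cloud.databricks.com",
                 product="my-app",
                 product_version="0.1.0")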

databricks/sdk/credentials_provider.py

Lines changed: 6 additions & 5 deletions
@@ -304,11 +304,12 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
     # detect Azure AD Tenant ID if it's not specified directly
     token_endpoint = cfg.oidc_endpoints.token_endpoint
     cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
-    inner = ClientCredentials(client_id=cfg.azure_client_id,
-                              client_secret="",  # we have no (rotatable) secrets in OIDC flow
-                              token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
-                              endpoint_params=params,
-                              use_params=True)
+    inner = ClientCredentials(
+        client_id=cfg.azure_client_id,
+        client_secret="",  # we have no (rotatable) secrets in OIDC flow
+        token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
+        endpoint_params=params,
+        use_params=True)
 
     def refreshed_headers() -> Dict[str, str]:
         token = inner.token()

tests/integration/test_auth.py

Lines changed: 10 additions & 9 deletions
@@ -133,15 +133,16 @@ def _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, lib
 
     tasks = []
     for v in dbr_versions:
-        t = Task(task_key=f'test_{v.key.replace(".", "_")}',
-                 notebook_task=NotebookTask(notebook_path=notebook_path),
-                 new_cluster=ClusterSpec(
-                     spark_version=v.key,
-                     num_workers=1,
-                     instance_pool_id=instance_pool_id,
-                     # GCP uses "custom" data security mode by default, which does not support UC.
-                     data_security_mode=DataSecurityMode.SINGLE_USER),
-                 libraries=[library])
+        t = Task(
+            task_key=f'test_{v.key.replace(".", "_")}',
+            notebook_task=NotebookTask(notebook_path=notebook_path),
+            new_cluster=ClusterSpec(
+                spark_version=v.key,
+                num_workers=1,
+                instance_pool_id=instance_pool_id,
+                # GCP uses "custom" data security mode by default, which does not support UC.
+                data_security_mode=DataSecurityMode.SINGLE_USER),
+            libraries=[library])
         tasks.append(t)
 
     waiter = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks)

tests/integration/test_jobs.py

Lines changed: 13 additions & 12 deletions
@@ -17,18 +17,19 @@ def test_submitting_jobs(w, random, env_or_skip):
     with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f:
         f.write(b'import time; time.sleep(10); print("Hello, World!")')
 
-    waiter = w.jobs.submit(run_name=f'py-sdk-{random(8)}',
-                           tasks=[
-                               jobs.SubmitTask(
-                                   task_key='pi',
-                                   new_cluster=compute.ClusterSpec(
-                                       spark_version=w.clusters.select_spark_version(long_term_support=True),
-                                       # node_type_id=w.clusters.select_node_type(local_disk=True),
-                                       instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
-                                       num_workers=1),
-                                   spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
-                               )
-                           ])
+    waiter = w.jobs.submit(
+        run_name=f'py-sdk-{random(8)}',
+        tasks=[
+            jobs.SubmitTask(
+                task_key='pi',
+                new_cluster=compute.ClusterSpec(
+                    spark_version=w.clusters.select_spark_version(long_term_support=True),
+                    # node_type_id=w.clusters.select_node_type(local_disk=True),
+                    instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
+                    num_workers=1),
+                spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
+            )
+        ])
 
     logging.info(f'starting to poll: {waiter.run_id}')

tests/test_base_client.py

Lines changed: 15 additions & 11 deletions
@@ -281,11 +281,13 @@ def inner(h: BaseHTTPRequestHandler):
     assert len(requests) == 2
 
 
-@pytest.mark.parametrize('chunk_size,expected_chunks,data_size',
-                         [(5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
-                          (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
-                          (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
-                          ])
+@pytest.mark.parametrize(
+    'chunk_size,expected_chunks,data_size',
+    [
+        (5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
+        (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
+        (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
+    ])
 def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
     rng = random.Random(42)
     test_data = bytes(rng.getrandbits(8) for _ in range(data_size))
@@ -355,12 +357,14 @@ def tell(self):
     assert client._is_seekable_stream(CustomSeekableStream())
 
 
-@pytest.mark.parametrize('input_data', [
-    b"0123456789",  # bytes -> BytesIO
-    "0123456789",  # str -> BytesIO
-    io.BytesIO(b"0123456789"),  # BytesIO directly
-    io.StringIO("0123456789"),  # StringIO
-])
+@pytest.mark.parametrize(
+    'input_data',
+    [
+        b"0123456789",  # bytes -> BytesIO
+        "0123456789",  # str -> BytesIO
+        io.BytesIO(b"0123456789"),  # BytesIO directly
+        io.StringIO("0123456789"),  # StringIO
+    ])
 def test_reset_seekable_stream_on_retry(input_data):
     received_data = []

tests/test_core.py

Lines changed: 14 additions & 8 deletions
@@ -370,14 +370,20 @@ def inner(h: BaseHTTPRequestHandler):
     assert {'Authorization': 'Taker this-is-it'} == headers
 
 
-@pytest.mark.parametrize(['azure_environment', 'expected'],
-                         [('PUBLIC', ENVIRONMENTS['PUBLIC']), ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
-                          ('CHINA', ENVIRONMENTS['CHINA']), ('public', ENVIRONMENTS['PUBLIC']),
-                          ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), ('china', ENVIRONMENTS['CHINA']),
-                          # Kept for historical compatibility
-                          ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
-                          ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
-                          ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ])
+@pytest.mark.parametrize(
+    ['azure_environment', 'expected'],
+    [
+        ('PUBLIC', ENVIRONMENTS['PUBLIC']),
+        ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
+        ('CHINA', ENVIRONMENTS['CHINA']),
+        ('public', ENVIRONMENTS['PUBLIC']),
+        ('usgovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('china', ENVIRONMENTS['CHINA']),
+        # Kept for historical compatibility
+        ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
+        ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('AzureChinaCloud', ENVIRONMENTS['CHINA']),
+    ])
 def test_azure_environment(azure_environment, expected):
     c = Config(credentials_strategy=noop_credentials,
                azure_workspace_resource_id='...',
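
The parametrize hunks in these test files change layout only: the decorator call opens on its own line and each case sits on its own line next to its comment. A standalone, hedged illustration of the same layout; the test below is invented for this example and is not part of the SDK test suite:

    import pytest


    @pytest.mark.parametrize(
        'chunk_size, expected_chunks, data_size',
        [
            (5, 20, 100),   # 100 bytes / 5 bytes per chunk = 20 chunks
            (10, 10, 100),  # 100 bytes / 10 bytes per chunk = 10 chunks
            (200, 1, 100),  # chunk size larger than the payload -> one chunk
        ])
    def test_chunk_count(chunk_size, expected_chunks, data_size):
        full_chunks, remainder = divmod(data_size, chunk_size)
        assert full_chunks + (1 if remainder else 0) == expected_chunks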

tests/test_errors.py

Lines changed: 47 additions & 46 deletions
@@ -83,52 +83,53 @@ def make_private_link_response() -> requests.Response:
     for x in base_subclass_test_cases]
 
 
-@pytest.mark.parametrize('response, expected_error, expected_message', subclass_test_cases + [
-    (fake_response('GET', 400, ''), errors.BadRequest, 'Bad Request'),
-    (fake_valid_response('GET', 417, 'WHOOPS', 'nope'), errors.DatabricksError, 'nope'),
-    (fake_valid_response('GET', 522, '', 'nope'), errors.DatabricksError, 'nope'),
-    (make_private_link_response(), errors.PrivateLinkValidationError,
-     ('The requested workspace has AWS PrivateLink enabled and is not accessible from the current network. '
-      'Ensure that AWS PrivateLink is properly configured and that your device has access to the AWS VPC '
-      'endpoint. For more information, see '
-      'https://docs.databricks.com/en/security/network/classic/privatelink.html.'),
-     ),
-    (fake_valid_response(
-        'GET', 400, 'INVALID_PARAMETER_VALUE', 'Cluster abcde does not exist',
-        '/api/2.0/clusters/get'), errors.ResourceDoesNotExist, 'Cluster abcde does not exist'),
-    (fake_valid_response('GET', 400, 'INVALID_PARAMETER_VALUE', 'Job abcde does not exist',
-                         '/api/2.0/jobs/get'), errors.ResourceDoesNotExist, 'Job abcde does not exist'),
-    (fake_valid_response('GET', 400, 'INVALID_PARAMETER_VALUE', 'Job abcde does not exist',
-                         '/api/2.1/jobs/get'), errors.ResourceDoesNotExist, 'Job abcde does not exist'),
-    (fake_valid_response('GET', 400, 'INVALID_PARAMETER_VALUE', 'Invalid spark version',
-                         '/api/2.1/jobs/get'), errors.InvalidParameterValue, 'Invalid spark version'),
-    (fake_response(
-        'GET', 400,
-        'MALFORMED_REQUEST: vpc_endpoints malformed parameters: VPC Endpoint ... with use_case ... cannot be attached in ... list'
-    ), errors.BadRequest,
-     'vpc_endpoints malformed parameters: VPC Endpoint ... with use_case ... cannot be attached in ... list'),
-    (fake_response('GET', 400, '<pre>Worker environment not ready</pre>'), errors.BadRequest,
-     'Worker environment not ready'),
-    (fake_response('GET', 400, 'this is not a real response'), errors.BadRequest,
-     ('unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. '
-      'Please report this issue with the following debugging information to the SDK issue tracker at '
-      'https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n'
-      '< 400 Bad Request\n'
-      '< this is not a real response```')),
-    (fake_response(
-        'GET', 404,
-        json.dumps({
-            'detail': 'Group with id 1234 is not found',
-            'status': '404',
-            'schemas': ['urn:ietf:params:scim:api:messages:2.0:Error']
-        })), errors.NotFound, 'None Group with id 1234 is not found'),
-    (fake_response('GET', 404, json.dumps("This is JSON but not a dictionary")), errors.NotFound,
-     'unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n< 404 Not Found\n< "This is JSON but not a dictionary"```'
-     ),
-    (fake_raw_response('GET', 404, b'\x80'), errors.NotFound,
-     'unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n< 404 Not Found\n< �```'
-     )
-])
+@pytest.mark.parametrize(
+    'response, expected_error, expected_message', subclass_test_cases +
+    [(fake_response('GET', 400, ''), errors.BadRequest, 'Bad Request'),
+     (fake_valid_response('GET', 417, 'WHOOPS', 'nope'), errors.DatabricksError, 'nope'),
+     (fake_valid_response('GET', 522, '', 'nope'), errors.DatabricksError, 'nope'),
+     (make_private_link_response(), errors.PrivateLinkValidationError,
+      ('The requested workspace has AWS PrivateLink enabled and is not accessible from the current network. '
+       'Ensure that AWS PrivateLink is properly configured and that your device has access to the AWS VPC '
+       'endpoint. For more information, see '
+       'https://docs.databricks.com/en/security/network/classic/privatelink.html.'),
+      ),
+     (fake_valid_response(
+         'GET', 400, 'INVALID_PARAMETER_VALUE', 'Cluster abcde does not exist',
+         '/api/2.0/clusters/get'), errors.ResourceDoesNotExist, 'Cluster abcde does not exist'),
+     (fake_valid_response('GET', 400, 'INVALID_PARAMETER_VALUE', 'Job abcde does not exist',
+                          '/api/2.0/jobs/get'), errors.ResourceDoesNotExist, 'Job abcde does not exist'),
+     (fake_valid_response('GET', 400, 'INVALID_PARAMETER_VALUE', 'Job abcde does not exist',
+                          '/api/2.1/jobs/get'), errors.ResourceDoesNotExist, 'Job abcde does not exist'),
+     (fake_valid_response('GET', 400, 'INVALID_PARAMETER_VALUE', 'Invalid spark version',
+                          '/api/2.1/jobs/get'), errors.InvalidParameterValue, 'Invalid spark version'),
+     (fake_response(
+         'GET', 400,
+         'MALFORMED_REQUEST: vpc_endpoints malformed parameters: VPC Endpoint ... with use_case ... cannot be attached in ... list'
+     ), errors.BadRequest,
+      'vpc_endpoints malformed parameters: VPC Endpoint ... with use_case ... cannot be attached in ... list'
+      ),
+     (fake_response('GET', 400, '<pre>Worker environment not ready</pre>'), errors.BadRequest,
+      'Worker environment not ready'),
+     (fake_response('GET', 400, 'this is not a real response'), errors.BadRequest,
+      ('unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. '
+       'Please report this issue with the following debugging information to the SDK issue tracker at '
+       'https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n'
+       '< 400 Bad Request\n'
+       '< this is not a real response```')),
+     (fake_response(
+         'GET', 404,
+         json.dumps({
+             'detail': 'Group with id 1234 is not found',
+             'status': '404',
+             'schemas': ['urn:ietf:params:scim:api:messages:2.0:Error']
+         })), errors.NotFound, 'None Group with id 1234 is not found'),
+     (fake_response('GET', 404, json.dumps("This is JSON but not a dictionary")), errors.NotFound,
+      'unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n< 404 Not Found\n< "This is JSON but not a dictionary"```'
+      ),
+     (fake_raw_response('GET', 404, b'\x80'), errors.NotFound,
+      'unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n< 404 Not Found\n< �```'
+     )])
 def test_get_api_error(response, expected_error, expected_message):
     parser = errors._Parser()
     with pytest.raises(errors.DatabricksError) as e:

tests/test_model_serving_auth.py

Lines changed: 10 additions & 7 deletions
@@ -47,13 +47,16 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
     assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'
 
 
-@pytest.mark.parametrize("env_values, oauth_file_name", [
-    ([], "invalid_file_name"),  # Not in Model Serving and Invalid File Name
-    ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"),  # In Model Serving and Invalid File Name
-    ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
-      ], "invalid_file_name"),  # In Model Serving and Invalid File Name
-    ([], "tests/testdata/model-serving-test-token")  # Not in Model Serving and Valid File Name
-])
+@pytest.mark.parametrize(
+    "env_values, oauth_file_name",
+    [
+        ([], "invalid_file_name"),  # Not in Model Serving and Invalid File Name
+        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"),  # In Model Serving and Invalid File Name
+        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"),  # In Model Serving and Invalid File Name
+        ([], "tests/testdata/model-serving-test-token")  # Not in Model Serving and Valid File Name
+    ])
 @raises(default_auth_base_error_message)
 def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
     # Guarantee that the tests defaults to env variables rather than config file.

tests/test_oauth.py

Lines changed: 7 additions & 6 deletions
@@ -10,25 +10,26 @@ def test_token_cache_unique_filename_by_host():
     common_args = dict(client_id="abc",
                        redirect_url="http://localhost:8020",
                        oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234"))
-    assert TokenCache(host="http://localhost:", **common_args).filename != TokenCache(
-        "https://bar.cloud.databricks.com", **common_args).filename
+    assert TokenCache(host="http://localhost:",
+                      **common_args).filename != TokenCache("https://bar.cloud.databricks.com",
+                                                            **common_args).filename
 
 
 def test_token_cache_unique_filename_by_client_id():
     common_args = dict(host="http://localhost:",
                        redirect_url="http://localhost:8020",
                        oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234"))
-    assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", **
-                                                                             common_args).filename
+    assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def",
+                                                                             **common_args).filename
 
 
 def test_token_cache_unique_filename_by_scopes():
     common_args = dict(host="http://localhost:",
                        client_id="abc",
                        redirect_url="http://localhost:8020",
                        oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234"))
-    assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], **
-                                                                            common_args).filename
+    assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"],
+                                                                            **common_args).filename
 
 
 def test_account_oidc_endpoints(requests_mock):
