
Commit ee6e70a

[Internal] Reformat SDK with YAPF 0.43. (#822)
## What changes are proposed in this pull request?

This PR is a no-op that reformats the SDK with the new version of `yapf` (0.43.0), which changed some formatting rules. This fixes a current issue with the `fmt` CI check, which is failing on the head of `main`.

## How is this tested?

N/A
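For context, a reformat like this can be reproduced through yapf's Python API. Below is a minimal sketch, assuming `yapf==0.43.0` is installed; the target directory and the `pep8` style fallback are illustrative assumptions, not the project's actual `fmt` setup:

```python
# Minimal sketch: rewrite every SDK source file that yapf 0.43 would reformat.
# Assumes `pip install yapf==0.43.0`; the path and style here are assumptions.
from pathlib import Path

from yapf.yapflib.yapf_api import FormatCode

for path in Path("databricks/sdk").rglob("*.py"):
    source = path.read_text()
    # style_config accepts a predefined style name or a path to a style file
    # (e.g. setup.cfg / .style.yapf); "pep8" is just an illustrative fallback.
    reformatted, changed = FormatCode(source, style_config="pep8")
    if changed:
        path.write_text(reformatted)
        print(f"reformatted: {path}")
```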
1 parent 271502b · commit ee6e70a

7 files changed, +70 −55 lines changed


databricks/sdk/config.py

Lines changed: 10 additions & 9 deletions
@@ -92,15 +92,16 @@ class Config:
     max_connections_per_pool: int = ConfigAttribute()
     databricks_environment: Optional[DatabricksEnvironment] = None

-    def __init__(self,
-                 *,
-                 # Deprecated. Use credentials_strategy instead.
-                 credentials_provider: Optional[CredentialsStrategy] = None,
-                 credentials_strategy: Optional[CredentialsStrategy] = None,
-                 product=None,
-                 product_version=None,
-                 clock: Optional[Clock] = None,
-                 **kwargs):
+    def __init__(
+            self,
+            *,
+            # Deprecated. Use credentials_strategy instead.
+            credentials_provider: Optional[CredentialsStrategy] = None,
+            credentials_strategy: Optional[CredentialsStrategy] = None,
+            product=None,
+            product_version=None,
+            clock: Optional[Clock] = None,
+            **kwargs):
         self._header_factory = None
         self._inner = {}
         self._user_agent_other_info = []

databricks/sdk/credentials_provider.py

Lines changed: 6 additions & 5 deletions
@@ -304,11 +304,12 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
     # detect Azure AD Tenant ID if it's not specified directly
     token_endpoint = cfg.oidc_endpoints.token_endpoint
     cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
-    inner = ClientCredentials(client_id=cfg.azure_client_id,
-                              client_secret="",  # we have no (rotatable) secrets in OIDC flow
-                              token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
-                              endpoint_params=params,
-                              use_params=True)
+    inner = ClientCredentials(
+        client_id=cfg.azure_client_id,
+        client_secret="",  # we have no (rotatable) secrets in OIDC flow
+        token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
+        endpoint_params=params,
+        use_params=True)

     def refreshed_headers() -> Dict[str, str]:
         token = inner.token()

tests/integration/test_auth.py

Lines changed: 10 additions & 9 deletions
@@ -133,15 +133,16 @@ def _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, lib

     tasks = []
     for v in dbr_versions:
-        t = Task(task_key=f'test_{v.key.replace(".", "_")}',
-                 notebook_task=NotebookTask(notebook_path=notebook_path),
-                 new_cluster=ClusterSpec(
-                     spark_version=v.key,
-                     num_workers=1,
-                     instance_pool_id=instance_pool_id,
-                     # GCP uses "custom" data security mode by default, which does not support UC.
-                     data_security_mode=DataSecurityMode.SINGLE_USER),
-                 libraries=[library])
+        t = Task(
+            task_key=f'test_{v.key.replace(".", "_")}',
+            notebook_task=NotebookTask(notebook_path=notebook_path),
+            new_cluster=ClusterSpec(
+                spark_version=v.key,
+                num_workers=1,
+                instance_pool_id=instance_pool_id,
+                # GCP uses "custom" data security mode by default, which does not support UC.
+                data_security_mode=DataSecurityMode.SINGLE_USER),
+            libraries=[library])
         tasks.append(t)

     waiter = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks)

tests/integration/test_jobs.py

Lines changed: 13 additions & 12 deletions
@@ -17,18 +17,19 @@ def test_submitting_jobs(w, random, env_or_skip):
     with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f:
         f.write(b'import time; time.sleep(10); print("Hello, World!")')

-    waiter = w.jobs.submit(run_name=f'py-sdk-{random(8)}',
-                           tasks=[
-                               jobs.SubmitTask(
-                                   task_key='pi',
-                                   new_cluster=compute.ClusterSpec(
-                                       spark_version=w.clusters.select_spark_version(long_term_support=True),
-                                       # node_type_id=w.clusters.select_node_type(local_disk=True),
-                                       instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
-                                       num_workers=1),
-                                   spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
-                               )
-                           ])
+    waiter = w.jobs.submit(
+        run_name=f'py-sdk-{random(8)}',
+        tasks=[
+            jobs.SubmitTask(
+                task_key='pi',
+                new_cluster=compute.ClusterSpec(
+                    spark_version=w.clusters.select_spark_version(long_term_support=True),
+                    # node_type_id=w.clusters.select_node_type(local_disk=True),
+                    instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
+                    num_workers=1),
+                spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
+            )
+        ])

     logging.info(f'starting to poll: {waiter.run_id}')

tests/test_base_client.py

Lines changed: 7 additions & 5 deletions
@@ -280,11 +280,13 @@ def inner(h: BaseHTTPRequestHandler):
     assert len(requests) == 2


-@pytest.mark.parametrize('chunk_size,expected_chunks,data_size',
-                         [(5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
-                          (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
-                          (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
-                          ])
+@pytest.mark.parametrize(
+    'chunk_size,expected_chunks,data_size',
+    [
+        (5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
+        (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
+        (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
+    ])
 def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
     rng = random.Random(42)
     test_data = bytes(rng.getrandbits(8) for _ in range(data_size))

tests/test_core.py

Lines changed: 14 additions & 8 deletions
@@ -370,14 +370,20 @@ def inner(h: BaseHTTPRequestHandler):
     assert {'Authorization': 'Taker this-is-it'} == headers


-@pytest.mark.parametrize(['azure_environment', 'expected'],
-                         [('PUBLIC', ENVIRONMENTS['PUBLIC']), ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
-                          ('CHINA', ENVIRONMENTS['CHINA']), ('public', ENVIRONMENTS['PUBLIC']),
-                          ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), ('china', ENVIRONMENTS['CHINA']),
-                          # Kept for historical compatibility
-                          ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
-                          ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
-                          ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ])
+@pytest.mark.parametrize(
+    ['azure_environment', 'expected'],
+    [
+        ('PUBLIC', ENVIRONMENTS['PUBLIC']),
+        ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
+        ('CHINA', ENVIRONMENTS['CHINA']),
+        ('public', ENVIRONMENTS['PUBLIC']),
+        ('usgovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('china', ENVIRONMENTS['CHINA']),
+        # Kept for historical compatibility
+        ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
+        ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('AzureChinaCloud', ENVIRONMENTS['CHINA']),
+    ])
 def test_azure_environment(azure_environment, expected):
     c = Config(credentials_strategy=noop_credentials,
                azure_workspace_resource_id='...',

tests/test_model_serving_auth.py

Lines changed: 10 additions & 7 deletions
@@ -47,13 +47,16 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
     assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'


-@pytest.mark.parametrize("env_values, oauth_file_name", [
-    ([], "invalid_file_name"),  # Not in Model Serving and Invalid File Name
-    ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"),  # In Model Serving and Invalid File Name
-    ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
-      ], "invalid_file_name"),  # In Model Serving and Invalid File Name
-    ([], "tests/testdata/model-serving-test-token")  # Not in Model Serving and Valid File Name
-])
+@pytest.mark.parametrize(
+    "env_values, oauth_file_name",
+    [
+        ([], "invalid_file_name"),  # Not in Model Serving and Invalid File Name
+        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"),  # In Model Serving and Invalid File Name
+        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"),  # In Model Serving and Invalid File Name
+        ([], "tests/testdata/model-serving-test-token")  # Not in Model Serving and Valid File Name
+    ])
 @raises(default_auth_base_error_message)
 def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
     # Guarantee that the tests defaults to env variables rather than config file.
