19 changes: 10 additions & 9 deletions databricks/sdk/config.py
@@ -92,15 +92,16 @@ class Config:
     max_connections_per_pool: int = ConfigAttribute()
     databricks_environment: Optional[DatabricksEnvironment] = None

-    def __init__(self,
-                 *,
-                 # Deprecated. Use credentials_strategy instead.
-                 credentials_provider: Optional[CredentialsStrategy] = None,
-                 credentials_strategy: Optional[CredentialsStrategy] = None,
-                 product=None,
-                 product_version=None,
-                 clock: Optional[Clock] = None,
-                 **kwargs):
+    def __init__(
+        self,
+        *,
+        # Deprecated. Use credentials_strategy instead.
+        credentials_provider: Optional[CredentialsStrategy] = None,
+        credentials_strategy: Optional[CredentialsStrategy] = None,
+        product=None,
+        product_version=None,
+        clock: Optional[Clock] = None,
+        **kwargs):
         self._header_factory = None
         self._inner = {}
         self._user_agent_other_info = []
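Aside from the whitespace change, the hunk's comment marks `credentials_provider` as deprecated in favour of `credentials_strategy`. A minimal, hedged usage sketch of the preferred keyword (the workspace URL and token are hypothetical, and `pat_auth` is assumed here as one of the SDK's built-in strategies):

    # Sketch only: pass an explicit credentials strategy instead of the
    # deprecated credentials_provider keyword.
    from databricks.sdk.core import Config
    from databricks.sdk.credentials_provider import pat_auth

    cfg = Config(
        host="https://example.cloud.databricks.com",  # hypothetical workspace URL
        token="dapi-example-token",  # hypothetical personal access token
        credentials_strategy=pat_auth)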
11 changes: 6 additions & 5 deletions databricks/sdk/credentials_provider.py
@@ -304,11 +304,12 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
         # detect Azure AD Tenant ID if it's not specified directly
         token_endpoint = cfg.oidc_endpoints.token_endpoint
         cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
-    inner = ClientCredentials(client_id=cfg.azure_client_id,
-                              client_secret="", # we have no (rotatable) secrets in OIDC flow
-                              token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
-                              endpoint_params=params,
-                              use_params=True)
+    inner = ClientCredentials(
+        client_id=cfg.azure_client_id,
+        client_secret="", # we have no (rotatable) secrets in OIDC flow
+        token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
+        endpoint_params=params,
+        use_params=True)

     def refreshed_headers() -> Dict[str, str]:
         token = inner.token()
19 changes: 10 additions & 9 deletions tests/integration/test_auth.py
@@ -133,15 +133,16 @@ def _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, lib

     tasks = []
     for v in dbr_versions:
-        t = Task(task_key=f'test_{v.key.replace(".", "_")}',
-                 notebook_task=NotebookTask(notebook_path=notebook_path),
-                 new_cluster=ClusterSpec(
-                     spark_version=v.key,
-                     num_workers=1,
-                     instance_pool_id=instance_pool_id,
-                     # GCP uses "custom" data security mode by default, which does not support UC.
-                     data_security_mode=DataSecurityMode.SINGLE_USER),
-                 libraries=[library])
+        t = Task(
+            task_key=f'test_{v.key.replace(".", "_")}',
+            notebook_task=NotebookTask(notebook_path=notebook_path),
+            new_cluster=ClusterSpec(
+                spark_version=v.key,
+                num_workers=1,
+                instance_pool_id=instance_pool_id,
+                # GCP uses "custom" data security mode by default, which does not support UC.
+                data_security_mode=DataSecurityMode.SINGLE_USER),
+            libraries=[library])
         tasks.append(t)

     waiter = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks)
25 changes: 13 additions & 12 deletions tests/integration/test_jobs.py
@@ -17,18 +17,19 @@ def test_submitting_jobs(w, random, env_or_skip):
     with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f:
         f.write(b'import time; time.sleep(10); print("Hello, World!")')

-    waiter = w.jobs.submit(run_name=f'py-sdk-{random(8)}',
-                           tasks=[
-                               jobs.SubmitTask(
-                                   task_key='pi',
-                                   new_cluster=compute.ClusterSpec(
-                                       spark_version=w.clusters.select_spark_version(long_term_support=True),
-                                       # node_type_id=w.clusters.select_node_type(local_disk=True),
-                                       instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
-                                       num_workers=1),
-                                   spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
-                               )
-                           ])
+    waiter = w.jobs.submit(
+        run_name=f'py-sdk-{random(8)}',
+        tasks=[
+            jobs.SubmitTask(
+                task_key='pi',
+                new_cluster=compute.ClusterSpec(
+                    spark_version=w.clusters.select_spark_version(long_term_support=True),
+                    # node_type_id=w.clusters.select_node_type(local_disk=True),
+                    instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
+                    num_workers=1),
+                spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
+            )
+        ])

     logging.info(f'starting to poll: {waiter.run_id}')

12 changes: 7 additions & 5 deletions tests/test_base_client.py
@@ -280,11 +280,13 @@ def inner(h: BaseHTTPRequestHandler):
     assert len(requests) == 2


-@pytest.mark.parametrize('chunk_size,expected_chunks,data_size',
-                         [(5, 20, 100), # 100 / 5 bytes per chunk = 20 chunks
-                          (10, 10, 100), # 100 / 10 bytes per chunk = 10 chunks
-                          (200, 1, 100), # 100 / 200 bytes per chunk = 1 chunk
-                          ])
+@pytest.mark.parametrize(
+    'chunk_size,expected_chunks,data_size',
+    [
+        (5, 20, 100), # 100 / 5 bytes per chunk = 20 chunks
+        (10, 10, 100), # 100 / 10 bytes per chunk = 10 chunks
+        (200, 1, 100), # 100 / 200 bytes per chunk = 1 chunk
+    ])
 def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
     rng = random.Random(42)
     test_data = bytes(rng.getrandbits(8) for _ in range(data_size))
22 changes: 14 additions & 8 deletions tests/test_core.py
@@ -370,14 +370,20 @@ def inner(h: BaseHTTPRequestHandler):
     assert {'Authorization': 'Taker this-is-it'} == headers


-@pytest.mark.parametrize(['azure_environment', 'expected'],
-                         [('PUBLIC', ENVIRONMENTS['PUBLIC']), ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
-                          ('CHINA', ENVIRONMENTS['CHINA']), ('public', ENVIRONMENTS['PUBLIC']),
-                          ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), ('china', ENVIRONMENTS['CHINA']),
-                          # Kept for historical compatibility
-                          ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
-                          ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
-                          ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ])
+@pytest.mark.parametrize(
+    ['azure_environment', 'expected'],
+    [
+        ('PUBLIC', ENVIRONMENTS['PUBLIC']),
+        ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
+        ('CHINA', ENVIRONMENTS['CHINA']),
+        ('public', ENVIRONMENTS['PUBLIC']),
+        ('usgovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('china', ENVIRONMENTS['CHINA']),
+        # Kept for historical compatibility
+        ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
+        ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('AzureChinaCloud', ENVIRONMENTS['CHINA']),
+    ])
 def test_azure_environment(azure_environment, expected):
     c = Config(credentials_strategy=noop_credentials,
                azure_workspace_resource_id='...',
17 changes: 10 additions & 7 deletions tests/test_model_serving_auth.py
@@ -47,13 +47,16 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
     assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'


-@pytest.mark.parametrize("env_values, oauth_file_name", [
-    ([], "invalid_file_name"), # Not in Model Serving and Invalid File Name
-    ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"), # In Model Serving and Invalid File Name
-    ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
-      ], "invalid_file_name"), # In Model Serving and Invalid File Name
-    ([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name
-])
+@pytest.mark.parametrize(
+    "env_values, oauth_file_name",
+    [
+        ([], "invalid_file_name"), # Not in Model Serving and Invalid File Name
+        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"), # In Model Serving and Invalid File Name
+        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"), # In Model Serving and Invalid File Name
+        ([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name
+    ])
 @raises(default_auth_base_error_message)
 def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
     # Guarantee that the tests defaults to env variables rather than config file.