Skip to content

Commit 4e2adb4

Browse files
authored
Merge branch 'main' into main
2 parents 92b261c + 7d22b4d commit 4e2adb4

File tree

4 files changed

+74
-16
lines changed

4 files changed

+74
-16
lines changed

tests/integration/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ def a(env_or_skip) -> AccountClient:
     return account_client
 
 
+@pytest.fixture(scope='session')
+def ucacct(env_or_skip) -> AccountClient:
+    _load_debug_env_if_runs_from_ide('ucacct')
+    env_or_skip("CLOUD_ENV")
+    account_client = AccountClient()
+    if not account_client.config.is_account_client:
+        pytest.skip("not Databricks Account client")
+    if 'TEST_METASTORE_ID' not in os.environ:
+        pytest.skip("not in Unity Catalog Workspace test env")
+    return account_client
+
+
 @pytest.fixture(scope='session')
 def w(env_or_skip) -> WorkspaceClient:
     _load_debug_env_if_runs_from_ide('workspace')
@@ -104,6 +116,14 @@ def volume(ucws, schema):
     ucws.volumes.delete(volume.full_name)
 
 
+@pytest.fixture()
+def workspace_dir(w, random):
+    directory = f'/Users/{w.current_user.me().user_name}/dir-{random(12)}'
+    w.workspace.mkdirs(directory)
+    yield directory
+    w.workspace.delete(directory, recursive=True)
+
+
 def _load_debug_env_if_runs_from_ide(key) -> bool:
     if not _is_in_debug():
         return False

tests/integration/test_auth.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
 import shutil
 import subprocess
 import sys
+import typing
 import urllib.parse
 from functools import partial
 from pathlib import Path
 
 import pytest
 
 from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode,
-                                            Library, ResultType)
+                                            Library, ResultType, SparkVersion)
 from databricks.sdk.service.jobs import NotebookTask, Task, ViewType
 from databricks.sdk.service.workspace import ImportFormat
 
@@ -84,19 +85,41 @@ def test_runtime_auth_from_interactive_on_uc(ucws, fresh_wheel_file, env_or_skip
     ucws.clusters.permanent_delete(interactive_cluster.cluster_id)
 
 
-def test_runtime_auth_from_jobs(w, fresh_wheel_file, env_or_skip, random):
-    instance_pool_id = env_or_skip('TEST_INSTANCE_POOL_ID')
-
+def _get_lts_versions(w) -> typing.List[SparkVersion]:
     v = w.clusters.spark_versions()
     lts_runtimes = [
         x for x in v.versions
         if 'LTS' in x.name and '-ml' not in x.key and '-photon' not in x.key and '-aarch64' not in x.key
     ]
+    return lts_runtimes
+
+
+def test_runtime_auth_from_jobs_volumes(ucws, fresh_wheel_file, env_or_skip, random, volume):
+    dbr_versions = [v for v in _get_lts_versions(ucws) if int(v.key.split('.')[0]) >= 15]
+
+    volume_wheel = f'{volume}/tmp/wheels/{random(10)}/{fresh_wheel_file.name}'
+    with fresh_wheel_file.open('rb') as f:
+        ucws.files.upload(volume_wheel, f)
+
+    lib = Library(whl=volume_wheel)
+    return _test_runtime_auth_from_jobs_inner(ucws, env_or_skip, random, dbr_versions, lib)
+
+
+def test_runtime_auth_from_jobs_dbfs(w, fresh_wheel_file, env_or_skip, random):
+    # Library installation from DBFS is not supported past DBR 14.3
+    dbr_versions = [v for v in _get_lts_versions(w) if int(v.key.split('.')[0]) < 15]
 
     dbfs_wheel = f'/tmp/wheels/{random(10)}/{fresh_wheel_file.name}'
     with fresh_wheel_file.open('rb') as f:
         w.dbfs.upload(dbfs_wheel, f)
 
+    lib = Library(whl=f'dbfs:{dbfs_wheel}')
+    return _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, lib)
+
+
+def _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, library):
+    instance_pool_id = env_or_skip('TEST_INSTANCE_POOL_ID')
+
     my_name = w.current_user.me().user_name
     notebook_path = f'/Users/{my_name}/notebook-native-auth'
     notebook_content = io.BytesIO(b'''
@@ -109,16 +132,20 @@ def test_runtime_auth_from_jobs(w, fresh_wheel_file, env_or_skip, random):
     w.workspace.upload(notebook_path, notebook_content, language=Language.PYTHON, overwrite=True)
 
     tasks = []
-    for v in lts_runtimes:
+    for v in dbr_versions:
         t = Task(task_key=f'test_{v.key.replace(".", "_")}',
                  notebook_task=NotebookTask(notebook_path=notebook_path),
-                 new_cluster=ClusterSpec(spark_version=v.key,
-                                         num_workers=1,
-                                         instance_pool_id=instance_pool_id),
-                 libraries=[Library(whl=f'dbfs:{dbfs_wheel}')])
+                 new_cluster=ClusterSpec(
+                     spark_version=v.key,
+                     num_workers=1,
+                     instance_pool_id=instance_pool_id,
+                     # GCP uses "custom" data security mode by default, which does not support UC.
+                     data_security_mode=DataSecurityMode.SINGLE_USER),
+                 libraries=[library])
         tasks.append(t)
 
-    run = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks).result()
+    waiter = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks)
+    run = waiter.result()
     for task_key, output in _task_outputs(w, run).items():
         assert my_name in output, f'{task_key} does not work with notebook native auth'
 

tests/integration/test_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
 import pytest
 
 
-def test_get_workspace_client(a, env_or_skip):
+def test_get_workspace_client(ucacct, env_or_skip):
+    # Need to switch to ucacct
     workspace_id = env_or_skip("TEST_WORKSPACE_ID")
-    ws = a.workspaces.get(workspace_id)
-    w = a.get_workspace_client(ws)
+    ws = ucacct.workspaces.get(workspace_id)
+    w = ucacct.get_workspace_client(ws)
     assert w.current_user.me().active
 
 
tests/integration/test_workspace.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,21 @@
 from databricks.sdk.service.workspace import ImportFormat, Language
 
 
-def test_workspace_recursive_list(w, random):
+def test_workspace_recursive_list(w, workspace_dir, random):
+    # create a file in the directory
+    file = f'{workspace_dir}/file-{random(12)}.py'
+    w.workspace.upload(file, io.BytesIO(b'print(1)'))
+    # create a subdirectory
+    subdirectory = f'{workspace_dir}/subdir-{random(12)}'
+    w.workspace.mkdirs(subdirectory)
+    # create a file in the subdirectory
+    subfile = f'{subdirectory}/subfile-{random(12)}.py'
+    w.workspace.upload(subfile, io.BytesIO(b'print(2)'))
+    # list the directory recursively
     names = []
-    for i in w.workspace.list(f'/Users/{w.current_user.me().user_name}', recursive=True):
+    for i in w.workspace.list(workspace_dir, recursive=True):
         names.append(i.path)
-    assert len(names) > 0
+    assert len(names) == 2
 
 
 def test_workspace_upload_download_notebooks(w, random):

0 commit comments

Comments
 (0)