Skip to content

Commit ae88ab3

Browse files
committed
fix tests
1 parent 5f63a8e commit ae88ab3

File tree

5 files changed

+97
-99
lines changed

5 files changed

+97
-99
lines changed

tests/integration/test_clusters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def test_cluster_events(w, env_or_skip):
2828
count += 1
2929
assert count > 0
3030

31-
32-
def test_ensure_cluster_is_running(w, env_or_skip):
33-
cluster_id = env_or_skip("TEST_DEFAULT_CLUSTER_ID")
34-
cc = ClustersClient(config=w)
35-
cc.ensure_cluster_is_running(cluster_id)
31+
# TODO: Re-enable this test after adding waiters to the SDK
32+
# def test_ensure_cluster_is_running(w, env_or_skip):
33+
# cluster_id = env_or_skip("TEST_DEFAULT_CLUSTER_ID")
34+
# cc = ClustersClient(config=w)
35+
# cc.ensure_cluster_is_running(cluster_id)
3636

3737

3838
# TODO: Re-enable this test after adding LRO support to the SDK

tests/integration/test_dbutils.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from databricks.sdk.databricks.core import DatabricksError
88
from databricks.sdk.databricks.errors import NotFound
9+
from databricks.sdk.databricks.dbutils import RemoteDbUtils
910

1011

1112
def test_rest_dbfs_ls(w, env_or_skip):
@@ -15,13 +16,15 @@ def test_rest_dbfs_ls(w, env_or_skip):
1516

1617
assert len(x) > 1
1718

19+
# TODO: Re-enable this test after adding waiters to the SDK
20+
# def test_proxy_dbfs_mounts(w, env_or_skip):
21+
22+
# w.cluster_id = env_or_skip("TEST_DEFAULT_CLUSTER_ID")
1823

19-
def test_proxy_dbfs_mounts(w, env_or_skip):
20-
w.config.cluster_id = env_or_skip("TEST_DEFAULT_CLUSTER_ID")
24+
# dbu = RemoteDbUtils(config=w)
25+
# x = dbu.fs.mounts()
2126

22-
x = w.dbutils.fs.mounts()
23-
24-
assert len(x) > 1
27+
# assert len(x) > 1
2528

2629

2730
@pytest.fixture(params=["dbfs", "volumes"])
@@ -54,8 +57,9 @@ def test_large_put(fs_and_base_path):
5457
def test_put_local_path(w, random, tmp_path):
5558
to_write = random(1024 * 1024 * 2)
5659
tmp_path = tmp_path / "tmp_file"
57-
w.dbutils.fs.put(f"file:{tmp_path}", to_write, True)
58-
assert w.dbutils.fs.head(f"file:{tmp_path}", 1024 * 1024 * 2) == to_write
60+
dbu = RemoteDbUtils(config=w)
61+
dbu.fs.put(f"file:{tmp_path}", to_write, True)
62+
assert dbu.fs.head(f"file:{tmp_path}", 1024 * 1024 * 2) == to_write
5963

6064

6165
def test_cp_file(fs_and_base_path, random):
@@ -184,9 +188,12 @@ def test_secrets(w, random):
184188
logger = logging.getLogger("foo")
185189
logger.info(f"Before loading secret: {random_value}")
186190

187-
w.secrets.create_scope(random_scope)
188-
w.secrets.put_secret(random_scope, key_for_string, string_value=random_value)
189-
w.secrets.put_secret(
191+
from databricks.sdk.workspace.v2.client import SecretsClient
192+
193+
sc = SecretsClient(config=w)
194+
sc.create_scope(random_scope)
195+
sc.put_secret(random_scope, key_for_string, string_value=random_value)
196+
sc.put_secret(
190197
random_scope,
191198
key_for_bytes,
192199
bytes_value=base64.b64encode(random_value.encode()).decode(),

tests/integration/test_files.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
from databricks.sdk.databricks.core import DatabricksError
11-
11+
from databricks.sdk.files.v2.client import DbfsClient
1212

1313
def test_local_io(random):
1414
if platform.system() == "Windows":
@@ -28,20 +28,19 @@ def test_local_io(random):
2828
def test_dbfs_io(w, random):
2929
dummy_file = f"/tmp/{random()}"
3030
to_write = random(1024 * 1024 * 1.5).encode()
31-
with w.dbfs.open(dummy_file, write=True) as f:
31+
dc = DbfsClient(config=w)
32+
with dc.open(dummy_file, write=True) as f:
3233
written = f.write(to_write)
3334
assert len(to_write) == written
3435

35-
f = w.dbfs.open(dummy_file, read=True)
36+
f = dc.open(dummy_file, read=True)
3637
from_dbfs = f.read()
3738
assert from_dbfs == to_write
3839
f.close()
3940

4041

4142
@pytest.fixture
4243
def junk(w, random):
43-
from databricks.sdk.files.v2.client import DbfsClient
44-
4544
dc = DbfsClient(config=w)
4645

4746
def inner(path: str, size=256) -> bytes:
@@ -56,8 +55,6 @@ def inner(path: str, size=256) -> bytes:
5655

5756
@pytest.fixture
5857
def ls(w):
59-
from databricks.sdk.files.v2.client import DbfsClient
60-
6158
dc = DbfsClient(config=w)
6259

6360
def inner(root: str, recursive=False) -> List[str]:
@@ -88,8 +85,6 @@ def test_cp_dbfs_folder_to_folder_non_recursive(w, random, junk, ls):
8885
junk(f"{root}/a/b/03")
8986
new_root = f"/tmp/{random()}"
9087

91-
from databricks.sdk.files.v2.client import DbfsClient
92-
9388
dc = DbfsClient(config=w)
9489

9590
dc.copy(root, new_root)
@@ -104,8 +99,6 @@ def test_cp_dbfs_folder_to_folder_recursive(w, random, junk, ls):
10499
junk(f"{root}/a/b/03")
105100
new_root = f"/tmp/{random()}"
106101

107-
from databricks.sdk.files.v2.client import DbfsClient
108-
109102
dc = DbfsClient(config=w)
110103

111104
dc.copy(root, new_root, recursive=True, overwrite=True)
@@ -120,8 +113,6 @@ def test_cp_dbfs_folder_to_existing_folder_recursive(w, random, junk, ls):
120113
junk(f"{root}/a/b/03")
121114
new_root = f"/tmp/{random()}"
122115

123-
from databricks.sdk.files.v2.client import DbfsClient
124-
125116
dc = DbfsClient(config=w)
126117

127118
dc.mkdirs(new_root)
@@ -136,8 +127,6 @@ def test_cp_dbfs_file_to_non_existing_location(w, random, junk):
136127
payload = junk(f"{root}/01")
137128
copy_destination = f"{root}/{random()}"
138129

139-
from databricks.sdk.files.v2.client import DbfsClient
140-
141130
dc = DbfsClient(config=w)
142131

143132
dc.copy(f"{root}/01", copy_destination)
@@ -150,8 +139,6 @@ def test_cp_dbfs_file_to_existing_folder(w, random, junk):
150139
root = f"/tmp/{random()}"
151140
payload = junk(f"{root}/01")
152141

153-
from databricks.sdk.files.v2.client import DbfsClient
154-
155142
dc = DbfsClient(config=w)
156143

157144
dc.mkdirs(f"{root}/02")
@@ -166,8 +153,6 @@ def test_cp_dbfs_file_to_existing_location(w, random, junk):
166153
junk(f"{root}/01")
167154
junk(f"{root}/02")
168155

169-
from databricks.sdk.files.v2.client import DbfsClient
170-
171156
dc = DbfsClient(config=w)
172157

173158
with pytest.raises(DatabricksError) as ei:
@@ -180,8 +165,6 @@ def test_cp_dbfs_file_to_existing_location_with_overwrite(w, random, junk):
180165
payload = junk(f"{root}/01")
181166
junk(f"{root}/02")
182167

183-
from databricks.sdk.files.v2.client import DbfsClient
184-
185168
dc = DbfsClient(config=w)
186169

187170
dc.copy(f"{root}/01", f"{root}/02", overwrite=True)
@@ -194,8 +177,6 @@ def test_move_within_dbfs(w, random, junk):
194177
root = f"/tmp/{random()}"
195178
payload = junk(f"{root}/01")
196179

197-
from databricks.sdk.files.v2.client import DbfsClient
198-
199180
dc = DbfsClient(config=w)
200181

201182
dc.move_(f"{root}/01", f"{root}/02")
@@ -211,8 +192,6 @@ def test_move_from_dbfs_to_local(w, random, junk, tmp_path):
211192
payload_02 = junk(f"{root}/a/02")
212193
payload_03 = junk(f"{root}/a/b/03")
213194

214-
from databricks.sdk.files.v2.client import DbfsClient
215-
216195
dc = DbfsClient(config=w)
217196

218197
dc.move_(root, f"file:{tmp_path}", recursive=True)
@@ -230,8 +209,6 @@ def test_dbfs_upload_download(w, random, junk, tmp_path):
230209
root = pathlib.Path(f"/tmp/{random()}")
231210

232211
f = io.BytesIO(b"some text data")
233-
from databricks.sdk.files.v2.client import DbfsClient
234-
235212
dc = DbfsClient(config=w)
236213

237214
dc.upload(f"{root}/01", f)

tests/integration/test_iam.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from databricks.sdk.databricks import errors
44
from databricks.sdk.databricks.core import DatabricksError
5-
from databricks.sdk.iam.v2.client import GroupsClient, UsersClient
5+
from databricks.sdk.iam.v2.client import AccountGroupsClient, AccountUsersClient, AccountServicePrincipalsClient
6+
from databricks.sdk.iam.v2.client import GroupsClient, UsersClient, ServicePrincipalsClient
7+
from databricks.sdk.databricks.core import ApiClient
68

79

810
def test_filtering_groups(w, random):
@@ -32,44 +34,36 @@ def test_scim_get_user_as_dict(w):
3234

3335

3436
@pytest.mark.parametrize(
35-
"path,call",
37+
"client_class,path,count",
3638
[
37-
("/api/2.0/preview/scim/v2/Users", lambda w: w.users.list(count=10)),
38-
("/api/2.0/preview/scim/v2/Groups", lambda w: w.groups.list(count=4)),
39-
(
40-
"/api/2.0/preview/scim/v2/ServicePrincipals",
41-
lambda w: w.service_principals.list(count=1),
42-
),
39+
(UsersClient, "/api/2.0/preview/scim/v2/Users", 10),
40+
(GroupsClient, "/api/2.0/preview/scim/v2/Groups", 40),
41+
(ServicePrincipalsClient, "/api/2.0/preview/scim/v2/ServicePrincipals", 10),
4342
],
4443
)
45-
def test_workspace_users_list_pagination(w, path, call):
46-
raw = w.api_client.do("GET", path)
44+
def test_workspace_users_list_pagination(w, client_class, path, count):
45+
client = client_class(config=w)
46+
api_client = ApiClient(cfg=w)
47+
raw = api_client.do("GET", path)
4748
total = raw["totalResults"]
48-
all = call(w)
49+
all = client.list(count=count)
4950
found = len(list(all))
5051
assert found == total
5152

5253

5354
@pytest.mark.parametrize(
54-
"path,call",
55+
"client_class,path,count",
5556
[
56-
(
57-
"/api/2.0/accounts/%s/scim/v2/Users",
58-
lambda a: a.users.list(count=3000),
59-
),
60-
(
61-
"/api/2.0/accounts/%s/scim/v2/Groups",
62-
lambda a: a.groups.list(count=5),
63-
),
64-
(
65-
"/api/2.0/accounts/%s/scim/v2/ServicePrincipals",
66-
lambda a: a.service_principals.list(count=1000),
67-
),
57+
(AccountUsersClient, "/api/2.0/accounts/%s/scim/v2/Users", 3000),
58+
(AccountGroupsClient, "/api/2.0/accounts/%s/scim/v2/Groups", 50),
59+
(AccountServicePrincipalsClient, "/api/2.0/accounts/%s/scim/v2/ServicePrincipals", 1000),
6860
],
6961
)
70-
def test_account_users_list_pagination(a, path, call):
71-
raw = a.api_client.do("GET", path.replace("%s", a.config.account_id))
62+
def test_account_users_list_pagination(a, client_class, path, count):
63+
client = client_class(config=a)
64+
api_client = ApiClient(cfg=a)
65+
raw = api_client.do("GET", path.replace("%s", a.account_id))
7266
total = raw["totalResults"]
73-
all = call(a)
67+
all = client.list(count=count)
7468
found = len(list(all))
7569
assert found == total

0 commit comments

Comments
 (0)