Skip to content

Commit a320c69

Browse files
committed
Add integration test for the new Files API client
1 parent e82e5f8 commit a320c69

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

databricks/sdk/mixins/files.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,7 @@ def perform():
10901090
# a 503 or 500 response, then you need to resume the interrupted upload from where it left off.
10911091

10921092
# Let's follow that for all potentially retryable status codes.
1093+
# Together with the catch block below we replicate the logic in _retry_idempotent_operation().
10931094
if upload_response.status_code in self._RETRYABLE_STATUS_CODES:
10941095
if retry_count < self._config.multipart_upload_max_retries:
10951096
retry_count += 1
@@ -1100,7 +1101,7 @@ def perform():
11001101
retry_count = 0
11011102

11021103
except RequestException as e:
1103-
# Let's do the same for retryable network errors
1104+
# Let's do the same for retryable network errors.
11041105
if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries:
11051106
retry_count += 1
11061107
upload_response = retrieve_upload_status()

tests/integration/test_files.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
import datetime
12
import io
23
import logging
34
import pathlib
45
import platform
6+
import re
57
import time
8+
from textwrap import dedent
69
from typing import Callable, List, Tuple, Union
710

811
import pytest
912

1013
from databricks.sdk.core import DatabricksError
14+
from databricks.sdk.errors.sdk import OperationFailed
1115
from databricks.sdk.service.catalog import VolumeType
1216

1317

@@ -382,3 +386,135 @@ def test_files_api_download_benchmark(ucws, files_api, random):
382386
)
383387
min_str = str(best[0]) + "kb" if best[0] else "None"
384388
logging.info("Fastest chunk size: %s in %f seconds", min_str, best[1])
389+
390+
391+
@pytest.mark.parametrize("is_serverless", [True, False], ids=["Classic", "Serverless"])
@pytest.mark.parametrize("use_new_files_api_client", [True, False], ids=["Default client", "Experimental client"])
def test_files_api_in_cluster(ucws, random, env_or_skip, is_serverless, use_new_files_api_client):
    """Round-trip a 100 MiB file through the Files API from inside a job run.

    Generates a launcher script, uploads it to DBFS, and submits it as a
    one-off job on either classic or serverless compute. The script uploads
    random content to a UC volume, downloads it back, and checks length and
    SHA-256. This test then greps the task output to assert that the expected
    Files API client class (FilesExt when the experimental-client env var is
    set, FilesAPI otherwise) was actually used on the cluster.
    """
    from databricks.sdk.service import compute, jobs

    sdk_package = "databricks-sdk"
    experimental_client_env_var = "DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT"

    launcher_file_path = f"/home/{ucws.current_user.me().user_name}/test_launcher.py"

    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(ucws, "main", schema), ResourceWithCleanup.create_volume(
        ucws, "main", schema, volume
    ):
        cloud_file_path = f"/Volumes/main/{schema}/{volume}/test-{random()}.txt"
        file_size = 100 * 1024 * 1024

        # Decide which client implementation the launcher should activate and
        # which class name it is expected to report back in its output.
        if use_new_files_api_client:
            enable_new_files_api_env = f"os.environ['{experimental_client_env_var}'] = 'True'"
            expected_files_api_client_class = "FilesExt"
        else:
            enable_new_files_api_env = ""
            expected_files_api_client_class = "FilesAPI"

        # Marker the launcher prints; the class name follows it on the same line.
        using_files_api_client_msg = "Using files API client: "

        # Script executed on the cluster. It prints progress markers that this
        # test later extracts from the task run logs.
        launcher_source = f"""
        from databricks.sdk import WorkspaceClient
        import io
        import os
        import hashlib
        import logging

        logging.basicConfig(level=logging.DEBUG)

        {enable_new_files_api_env}

        file_size = {file_size}
        original_content = os.urandom(file_size)
        cloud_file_path = '{cloud_file_path}'

        w = WorkspaceClient()
        print(f"Using SDK: {{w.config._product_info}}")

        print(f"{using_files_api_client_msg}{{type(w.files).__name__}}")

        w.files.upload(cloud_file_path, io.BytesIO(original_content), overwrite=True)
        print("Upload succeeded")

        response = w.files.download(cloud_file_path)
        resulting_content = response.contents.read()
        print("Download succeeded")

        def hash(data: bytes):
            sha256 = hashlib.sha256()
            sha256.update(data)
            return sha256.hexdigest()

        if len(resulting_content) != len(original_content):
            raise ValueError(f"Content length does not match: expected {{len(original_content)}}, actual {{len(resulting_content)}}")

        expected_hash = hash(original_content)
        actual_hash = hash(resulting_content)
        if actual_hash != expected_hash:
            raise ValueError(f"Content hash does not match: expected {{expected_hash}}, actual {{actual_hash}}")

        print(f"Contents of size {{len(resulting_content)}} match")
        """

        with ucws.dbfs.open(launcher_file_path, write=True, overwrite=True) as f:
            f.write(dedent(launcher_source).encode())

        if is_serverless:
            # With no job_cluster_key, existing_cluster_id, or new_cluster on
            # the task, it runs on serverless compute; the SDK dependency is
            # then declared via a job environment.
            cluster_spec = None
            env_key = "test_env"
            job_environments = [jobs.JobEnvironment(env_key, compute.Environment("test", [sdk_package]))]
            task_libraries = []
        else:
            cluster_spec = compute.ClusterSpec(
                spark_version=ucws.clusters.select_spark_version(long_term_support=True),
                instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"),
                num_workers=1,
            )
            # On classic compute the SDK dependency is attached as a task library.
            env_key = None
            job_environments = []
            task_libraries = [compute.Library(pypi=compute.PythonPyPiLibrary(package=sdk_package))]

        waiter = ucws.jobs.submit(
            run_name=f"py-sdk-{random(8)}",
            tasks=[
                jobs.SubmitTask(
                    task_key="task1",
                    new_cluster=cluster_spec,
                    spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{launcher_file_path}"),
                    libraries=task_libraries,
                    environment_key=env_key,
                )
            ],
            environments=job_environments,
        )

        def report_progress(r: jobs.Run):
            # Periodic callback from waiter.result(): log each task's state.
            statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in r.tasks]
            logging.info(f'Run status: {", ".join(statuses)}')

        logging.info(f"Waiting for the job run: {waiter.run_id}")
        try:
            job_run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=report_progress)
            task_run_id = job_run.tasks[0].run_id
            task_run_logs = ucws.jobs.get_run_output(task_run_id).logs
            logging.info(f"Run finished, output: {task_run_logs}")
            # The launcher prints the concrete client class right after the
            # marker; verify it matches the variant under test.
            match = re.search(f"{using_files_api_client_msg}(.*)$", task_run_logs, re.MULTILINE)
            assert match is not None
            assert match.group(1) == expected_files_api_client_class
        except OperationFailed:
            # Surface the remote error and logs instead of the generic failure.
            job_run = ucws.jobs.get_run(waiter.run_id)
            task_run_id = job_run.tasks[0].run_id
            task_run_output = ucws.jobs.get_run_output(task_run_id)
            raise ValueError(
                f"Run failed, error: {task_run_output.error}, error trace: {task_run_output.error_trace}, output: {task_run_output.logs}"
            )

0 commit comments

Comments (0)