|
| 1 | +import datetime |
1 | 2 | import io |
2 | 3 | import logging |
3 | 4 | import pathlib |
4 | 5 | import platform |
| 6 | +import re |
5 | 7 | import time |
| 8 | +from textwrap import dedent |
6 | 9 | from typing import Callable, List, Tuple, Union |
7 | 10 |
|
8 | 11 | import pytest |
9 | 12 |
|
10 | 13 | from databricks.sdk.core import DatabricksError |
| 14 | +from databricks.sdk.errors.sdk import OperationFailed |
11 | 15 | from databricks.sdk.service.catalog import VolumeType |
12 | 16 |
|
13 | 17 |
|
@@ -382,3 +386,135 @@ def test_files_api_download_benchmark(ucws, files_api, random): |
382 | 386 | ) |
383 | 387 | min_str = str(best[0]) + "kb" if best[0] else "None" |
384 | 388 | logging.info("Fastest chunk size: %s in %f seconds", min_str, best[1]) |
| 389 | + |
| 390 | + |
@pytest.mark.parametrize("is_serverless", [True, False], ids=["Classic", "Serverless"])
@pytest.mark.parametrize("use_new_files_api_client", [True, False], ids=["Default client", "Experimental client"])
def test_files_api_in_cluster(ucws, random, env_or_skip, is_serverless, use_new_files_api_client):
    """Round-trip a file through the Files API from inside a Databricks job.

    Writes a generated launcher script to DBFS and submits it as a one-time
    job run (serverless or classic compute). The script uploads 100 MiB of
    random bytes to a UC volume, downloads them back and verifies length and
    SHA-256. The test then parses the run logs to confirm the job used the
    expected Files API client implementation ("FilesAPI" by default,
    "FilesExt" when the experimental client is enabled via env var).

    Fixtures:
        ucws: workspace client fixture.
        random: random-string factory fixture.
        env_or_skip: returns an env var value or skips the test.
    """
    from databricks.sdk.service import compute, jobs

    databricks_sdk_pypi_package = "databricks-sdk"
    option_env_name = "DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT"

    launcher_file_path = f"/home/{ucws.current_user.me().user_name}/test_launcher.py"

    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(ucws, "main", schema):
        with ResourceWithCleanup.create_volume(ucws, "main", schema, volume):

            cloud_file_path = f"/Volumes/main/{schema}/{volume}/test-{random()}.txt"
            file_size = 100 * 1024 * 1024  # 100 MiB, large enough to exercise multi-chunk transfer

            if use_new_files_api_client:
                # The experimental client is opted into by setting this env var
                # inside the job process, before WorkspaceClient is constructed.
                enable_new_files_api_env = f"os.environ['{option_env_name}'] = 'True'"
                expected_files_api_client_class = "FilesExt"
            else:
                enable_new_files_api_env = ""
                expected_files_api_client_class = "FilesAPI"

            # Marker line printed by the launcher; parsed out of the run logs below.
            using_files_api_client_msg = "Using files API client: "

            command = f"""
            from databricks.sdk import WorkspaceClient
            import io
            import os
            import hashlib
            import logging

            logging.basicConfig(level=logging.DEBUG)

            {enable_new_files_api_env}

            file_size = {file_size}
            original_content = os.urandom(file_size)
            cloud_file_path = '{cloud_file_path}'

            w = WorkspaceClient()
            print(f"Using SDK: {{w.config._product_info}}")

            print(f"{using_files_api_client_msg}{{type(w.files).__name__}}")

            w.files.upload(cloud_file_path, io.BytesIO(original_content), overwrite=True)
            print("Upload succeeded")

            response = w.files.download(cloud_file_path)
            resulting_content = response.contents.read()
            print("Download succeeded")

            def hash(data: bytes):
                sha256 = hashlib.sha256()
                sha256.update(data)
                return sha256.hexdigest()

            if len(resulting_content) != len(original_content):
                raise ValueError(f"Content length does not match: expected {{len(original_content)}}, actual {{len(resulting_content)}}")

            expected_hash = hash(original_content)
            actual_hash = hash(resulting_content)
            if actual_hash != expected_hash:
                raise ValueError(f"Content hash does not match: expected {{expected_hash}}, actual {{actual_hash}}")

            print(f"Contents of size {{len(resulting_content)}} match")
            """

            with ucws.dbfs.open(launcher_file_path, write=True, overwrite=True) as f:
                f.write(dedent(command).encode())

            if is_serverless:
                # If no job_cluster_key, existing_cluster_id, or new_cluster were specified
                # in the task definition, then the task is executed on serverless compute.
                new_cluster_spec = None

                # Library is specified in the environment
                env_key = "test_env"
                envs = [jobs.JobEnvironment(env_key, compute.Environment("test", [databricks_sdk_pypi_package]))]
                libs = []
            else:
                new_cluster_spec = compute.ClusterSpec(
                    spark_version=ucws.clusters.select_spark_version(long_term_support=True),
                    instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"),
                    num_workers=1,
                )

                # Library is specified in the task definition
                env_key = None
                envs = []
                libs = [compute.Library(pypi=compute.PythonPyPiLibrary(package=databricks_sdk_pypi_package))]

            waiter = ucws.jobs.submit(
                run_name=f"py-sdk-{random(8)}",
                tasks=[
                    jobs.SubmitTask(
                        task_key="task1",
                        new_cluster=new_cluster_spec,
                        spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{launcher_file_path}"),
                        libraries=libs,
                        environment_key=env_key,
                    )
                ],
                environments=envs,
            )

            def print_status(r: jobs.Run):
                # Progress callback invoked by waiter.result() while polling.
                statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in r.tasks]
                logging.info(f'Run status: {", ".join(statuses)}')

            logging.info(f"Waiting for the job run: {waiter.run_id}")
            try:
                job_run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status)
                task_run_id = job_run.tasks[0].run_id
                task_run_logs = ucws.jobs.get_run_output(task_run_id).logs
                logging.info(f"Run finished, output: {task_run_logs}")
                # Fail with a clear assertion (not a TypeError from re.search) if logs are absent.
                assert task_run_logs is not None
                # Escape the literal marker before building the regex, and strip the
                # capture so a trailing '\r' from CRLF log lines cannot break the
                # class-name comparison below.
                match = re.search(re.escape(using_files_api_client_msg) + r"(.*)$", task_run_logs, re.MULTILINE)
                assert match is not None
                files_api_client_class = match.group(1).strip()
                assert files_api_client_class == expected_files_api_client_class

            except OperationFailed as e:
                # Pull the failed task's output so the failure message is actionable.
                job_run = ucws.jobs.get_run(waiter.run_id)
                task_run_id = job_run.tasks[0].run_id
                task_run_logs = ucws.jobs.get_run_output(task_run_id)
                raise ValueError(
                    f"Run failed, error: {task_run_logs.error}, error trace: {task_run_logs.error_trace}, output: {task_run_logs.logs}"
                ) from e
0 commit comments