|
1 | | -import datetime |
2 | 1 | import io |
3 | 2 | import logging |
4 | 3 | import pathlib |
5 | 4 | import platform |
6 | | -import re |
7 | 5 | import time |
8 | | -from textwrap import dedent |
9 | 6 | from typing import Callable, List, Tuple, Union |
10 | 7 |
|
11 | 8 | import pytest |
12 | 9 |
|
13 | 10 | from databricks.sdk.core import DatabricksError |
14 | | -from databricks.sdk.errors.sdk import OperationFailed |
15 | 11 | from databricks.sdk.service.catalog import VolumeType |
16 | 12 |
|
17 | 13 |
|
@@ -386,135 +382,3 @@ def test_files_api_download_benchmark(ucws, files_api, random): |
386 | 382 | ) |
387 | 383 | min_str = str(best[0]) + "kb" if best[0] else "None" |
388 | 384 | logging.info("Fastest chunk size: %s in %f seconds", min_str, best[1]) |
389 | | - |
390 | | - |
# NOTE: `ids` are positional, so they must be listed in the same order as the
# parameter values ([True, False]); otherwise reports label each run with the
# opposite variant.
@pytest.mark.parametrize("is_serverless", [True, False], ids=["Serverless", "Classic"])
@pytest.mark.parametrize("use_new_files_api_client", [True, False], ids=["Experimental client", "Default client"])
def test_files_api_in_cluster(ucws, random, env_or_skip, is_serverless, use_new_files_api_client):
    """Round-trip a file through the Files API from inside a submitted job run.

    A launcher script is written to DBFS and submitted as a one-off
    ``spark_python_task``. The script uploads 100 MiB of random bytes to a
    file in a freshly created volume, downloads it again and compares length
    and SHA-256. It also prints the class name of ``w.files`` so this test can
    assert which Files API client implementation was used.

    Args:
        ucws: Workspace client fixture.
        random: Fixture returning random strings (optionally of a given length).
        env_or_skip: Fixture returning an environment variable's value
            (presumably skipping the test when it is missing - confirm in conftest).
        is_serverless: When True the task is submitted without any cluster
            spec and the SDK is installed through a job environment; when
            False a new cluster from ``TEST_INSTANCE_POOL_ID`` is used and the
            SDK is attached as a task library.
        use_new_files_api_client: When True the launcher sets
            ``DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT`` and the client
            class is expected to be ``FilesExt``; otherwise ``FilesAPI``.

    Raises:
        ValueError: If the job run fails; the message carries the run's error,
            error trace and logs.
    """
    from databricks.sdk.service import compute, jobs

    databricks_sdk_pypi_package = "databricks-sdk"
    option_env_name = "DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT"

    launcher_file_path = f"/home/{ucws.current_user.me().user_name}/test_launcher.py"

    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(ucws, "main", schema), ResourceWithCleanup.create_volume(
        ucws, "main", schema, volume
    ):
        cloud_file_path = f"/Volumes/main/{schema}/{volume}/test-{random()}.txt"
        file_size = 100 * 1024 * 1024

        if use_new_files_api_client:
            enable_new_files_api_env = f"os.environ['{option_env_name}'] = 'True'"
            expected_files_api_client_class = "FilesExt"
        else:
            enable_new_files_api_env = ""
            expected_files_api_client_class = "FilesAPI"

        # Marker printed by the launcher; the run logs are searched for it below.
        using_files_api_client_msg = "Using files API client: "

        # Source of the launcher script executed by the job run. Doubled
        # braces are literal braces in the generated script; single braces are
        # substituted here, before the text is dedented.
        command = f"""
            from databricks.sdk import WorkspaceClient
            import io
            import os
            import hashlib
            import logging

            logging.basicConfig(level=logging.DEBUG)

            {enable_new_files_api_env}

            file_size = {file_size}
            original_content = os.urandom(file_size)
            cloud_file_path = '{cloud_file_path}'

            w = WorkspaceClient()
            print(f"Using SDK: {{w.config._product_info}}")

            print(f"{using_files_api_client_msg}{{type(w.files).__name__}}")

            w.files.upload(cloud_file_path, io.BytesIO(original_content), overwrite=True)
            print("Upload succeeded")

            response = w.files.download(cloud_file_path)
            resulting_content = response.contents.read()
            print("Download succeeded")

            def hash(data: bytes):
                sha256 = hashlib.sha256()
                sha256.update(data)
                return sha256.hexdigest()

            if len(resulting_content) != len(original_content):
                raise ValueError(f"Content length does not match: expected {{len(original_content)}}, actual {{len(resulting_content)}}")

            expected_hash = hash(original_content)
            actual_hash = hash(resulting_content)
            if actual_hash != expected_hash:
                raise ValueError(f"Content hash does not match: expected {{expected_hash}}, actual {{actual_hash}}")

            print(f"Contents of size {{len(resulting_content)}} match")
        """

        with ucws.dbfs.open(launcher_file_path, write=True, overwrite=True) as f:
            f.write(dedent(command).encode())

        if is_serverless:
            # If no job_cluster_key, existing_cluster_id, or new_cluster were specified in task definition,
            # then task will be executed using serverless compute.
            new_cluster_spec = None

            # Library is specified in the environment
            env_key = "test_env"
            envs = [jobs.JobEnvironment(env_key, compute.Environment("test", [databricks_sdk_pypi_package]))]
            libs = []
        else:
            new_cluster_spec = compute.ClusterSpec(
                spark_version=ucws.clusters.select_spark_version(long_term_support=True),
                instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"),
                num_workers=1,
            )

            # Library is specified in the task definition
            env_key = None
            envs = []
            libs = [compute.Library(pypi=compute.PythonPyPiLibrary(package=databricks_sdk_pypi_package))]

        waiter = ucws.jobs.submit(
            run_name=f"py-sdk-{random(8)}",
            tasks=[
                jobs.SubmitTask(
                    task_key="task1",
                    new_cluster=new_cluster_spec,
                    spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{launcher_file_path}"),
                    libraries=libs,
                    environment_key=env_key,
                )
            ],
            environments=envs,
        )

        def print_status(r: jobs.Run):
            # Progress callback: log each task's life-cycle state while waiting.
            statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in r.tasks]
            logging.info("Run status: %s", ", ".join(statuses))

        logging.info("Waiting for the job run: %s", waiter.run_id)
        try:
            job_run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status)
            task_run_id = job_run.tasks[0].run_id
            task_run_logs = ucws.jobs.get_run_output(task_run_id).logs
            logging.info("Run finished, output: %s", task_run_logs)
            # The launcher prints the client class name right after the marker.
            match = re.search(f"{using_files_api_client_msg}(.*)$", task_run_logs, re.MULTILINE)
            assert match is not None
            files_api_client_class = match.group(1)
            assert files_api_client_class == expected_files_api_client_class

        except OperationFailed as err:
            # Fetch the failed task's output so the failure message is actionable.
            job_run = ucws.jobs.get_run(waiter.run_id)
            task_run_id = job_run.tasks[0].run_id
            task_run_logs = ucws.jobs.get_run_output(task_run_id)
            raise ValueError(
                f"Run failed, error: {task_run_logs.error}, error trace: {task_run_logs.error_trace}, output: {task_run_logs.logs}"
            ) from err
0 commit comments