diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py
index 44acb7c0f2..895c7d153d 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py
@@ -40,7 +40,7 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
         if not new_cluster.get("docker_image"):
             new_cluster["docker_image"] = {"url": container.image}
         if not new_cluster.get("spark_conf"):
-            new_cluster["spark_conf"] = custom["sparkConf"]
+            new_cluster["spark_conf"] = custom.get("sparkConf", {})
         if not new_cluster.get("spark_env_vars"):
             new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}
     else:
diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py
index ff3e1797b7..7198a4dec0 100644
--- a/plugins/flytekit-spark/tests/test_spark_task.py
+++ b/plugins/flytekit-spark/tests/test_spark_task.py
@@ -10,7 +10,7 @@
 from flytekit import PodTemplate
 from flytekit.core import context_manager
 from flytekitplugins.spark import Spark
-from flytekitplugins.spark.task import Databricks, new_spark_session
+from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session
 from pyspark.sql import SparkSession
 
 import flytekit
@@ -135,6 +135,46 @@ def my_databricks(a: int) -> int:
     assert my_databricks(a=3) == 3
 
 
+@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}])
+def test_databricks_v2(reset_spark_session, spark_conf):
+    databricks_conf = {
+        "name": "flytekit databricks plugin example",
+        "new_cluster": {
+            "spark_version": "11.0.x-scala2.12",
+            "node_type_id": "r3.xlarge",
+            "aws_attributes": {"availability": "ON_DEMAND"},
+            "num_workers": 4,
+            "docker_image": {"url": "pingsutw/databricks:latest"},
+        },
+        "timeout_seconds": 3600,
+        "max_retries": 1,
+        "spark_python_task": {
+            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
+            "parameters": "ls",
+        },
+    }
+
+    databricks_instance = "account.cloud.databricks.com"
+
+    @task(
+        task_config=DatabricksV2(
+            databricks_conf=databricks_conf,
+            databricks_instance=databricks_instance,
+            spark_conf=spark_conf,
+        )
+    )
+    def my_databricks(a: int) -> int:
+        session = flytekit.current_context().spark_session
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return a
+
+    assert my_databricks.task_config is not None
+    assert my_databricks.task_config.databricks_conf == databricks_conf
+    assert my_databricks.task_config.databricks_instance == databricks_instance
+    assert my_databricks.task_config.spark_conf == (spark_conf or {})
+    assert my_databricks(a=3) == 3
+
+
 def test_new_spark_session():
     name = "SessionName"
     spark_conf = {"spark1": "1", "spark2": "2"}