Commit e5b1708

feat: Add support for running Spark Connect client inside the Dataproc s8s batch (#150)
Co-authored-by: Zhiwei Lin <[email protected]>
Parent: 3cb5ec8

3 files changed: 48 additions, 0 deletions

google/cloud/dataproc_spark_connect/environment.py (4 additions, 0 deletions)

@@ -67,6 +67,10 @@ def is_interactive_terminal():
     return is_interactive() and is_terminal()
 
 
+def is_dataproc_batch() -> bool:
+    return os.getenv("DATAPROC_WORKLOAD_TYPE") == "batch"
+
+
 def get_client_environment_label() -> str:
     """
     Map current environment to a standardized client label.
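
The new helper keys off a single environment variable, DATAPROC_WORKLOAD_TYPE, which the Dataproc Serverless (s8s) batch runtime is expected to set to "batch". A quick sketch of its behavior, setting the variable manually here purely for illustration:

    import os
    from google.cloud.dataproc_spark_connect import environment

    # Simulate running inside a Dataproc batch workload.
    os.environ["DATAPROC_WORKLOAD_TYPE"] = "batch"
    assert environment.is_dataproc_batch()

    # Outside a batch workload the variable is unset, so the check is False.
    del os.environ["DATAPROC_WORKLOAD_TYPE"]
    assert not environment.is_dataproc_batch()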

google/cloud/dataproc_spark_connect/session.py (7 additions, 0 deletions)

@@ -576,6 +576,13 @@ def _get_exiting_active_session(
 
     def getOrCreate(self) -> "DataprocSparkSession":
         with DataprocSparkSession._lock:
+            if environment.is_dataproc_batch():
+                # For Dataproc batch workloads, connect to the already initialized local SparkSession
+                from pyspark.sql import SparkSession as PySparkSQLSession
+
+                session = PySparkSQLSession.builder.getOrCreate()
+                return session  # type: ignore
+
             # Handle custom session ID by setting it early and letting existing logic handle it
             if self._custom_session_id:
                 self._handle_custom_session_id()
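
With this change, code that builds a DataprocSparkSession runs unmodified inside a batch workload: getOrCreate() short-circuits and returns the local SparkSession that the batch runtime already initialized, instead of provisioning a remote Spark Connect session. A minimal usage sketch, assuming the package exports DataprocSparkSession at its root and the code runs where DATAPROC_WORKLOAD_TYPE=batch:

    from google.cloud.dataproc_spark_connect import DataprocSparkSession

    # Inside a Dataproc batch this returns the pre-initialized local
    # pyspark.sql.SparkSession rather than a Spark Connect session.
    spark = DataprocSparkSession.builder.getOrCreate()
    spark.sql("SELECT 1 AS ok").show()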

tests/integration/test_session.py (37 additions, 0 deletions)

@@ -664,3 +664,40 @@ def test_sparksql_magic_with_dataproc_session(connect_session):
     assert row["multiplication"] == 50
     assert row["square_root"] == 4.0
     assert row["joined_string"] == "Dataproc-Spark"
+
+
+@pytest.fixture
+def batch_workload_env(monkeypatch):
+    """Sets DATAPROC_WORKLOAD_TYPE to 'batch' for a test."""
+    monkeypatch.setenv("DATAPROC_WORKLOAD_TYPE", "batch")
+
+
+@pytest.fixture
+def local_spark_session():
+    """Provides a standard local PySpark session for comparison."""
+    from pyspark.sql import SparkSession as PySparkSession
+
+    # Stop any existing session to ensure a clean environment for creating a local session.
+    # This prevents test isolation failures where a Dataproc session from a previous
+    # test might be picked up by getOrCreate().
+    if DataprocSparkSession.getActiveSession():
+        DataprocSparkSession.getActiveSession().stop()
+
+    session = PySparkSession.builder.master("local").getOrCreate()
+    yield session
+    session.stop()
+
+
+def test_create_local_spark_session(batch_workload_env, local_spark_session):
+    """Test creating a local Spark session."""
+    from pyspark.sql import SparkSession as PySparkSession
+
+    dataproc_spark_session = DataprocSparkSession.builder.getOrCreate()
+    try:
+        assert isinstance(dataproc_spark_session, PySparkSession)
+        assert not isinstance(dataproc_spark_session, DataprocSparkSession)
+
+        # Compare configurations to ensure they are both local sessions
+        assert dataproc_spark_session == local_spark_session
+    finally:
+        dataproc_spark_session.stop()
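
Since getOrCreate() now delegates to PySpark's builder under the batch environment variable, the fixture's session and the one returned inside the test resolve to the same underlying local SparkSession, which is why the final equality assertion passes. Assuming a standard pytest setup, the new test can be run on its own with:

    pytest tests/integration/test_session.py -k test_create_local_spark_session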
