Skip to content

Commit 6d4901e

Browse files
committed
fix: pysparkvalue error
1 parent 21b6239 commit 6d4901e

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

google/cloud/dataproc_spark_connect/session.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,27 @@ def create_session_pbar():
472472
session_response, dataproc_config.name
473473
)
474474

475+
def _wait_for_session_available(
    self,
    session_name: str,
    timeout: int = 300,
    poll_interval: float = 5,
) -> "Session":
    """Poll a session until it exposes the Spark Connect endpoint.

    The session is fetched through ``self.session_controller_client`` and
    is considered available once the ``"Spark Connect Server"`` key is
    present in ``session.runtime_info.endpoints``.

    Args:
        session_name: Name of the session, passed as ``name`` to
            ``get_session``.
        timeout: Maximum number of seconds to keep polling. The session is
            always fetched at least once, even when this is ``0``.
        poll_interval: Seconds to sleep between two polls. The last sleep
            is shortened so the wait never runs past ``timeout``.

    Returns:
        The first fetched session that lists the Spark Connect endpoint.

    Raises:
        RuntimeError: If the endpoint is still missing once ``timeout``
            seconds have elapsed. When the most recent failure was an
            exception raised by ``get_session``, it is attached as the
            cause.
    """
    # A monotonic clock keeps the deadline immune to wall-clock changes.
    deadline = time.monotonic() + timeout
    last_error = None

    while True:
        try:
            session = self.session_controller_client.get_session(
                name=session_name
            )
        except Exception as e:
            # Best effort: polling errors are logged and retried until the
            # deadline instead of aborting the wait.
            last_error = e
            logger.warning(
                f"Error while polling for Spark Connect endpoint: {e}"
            )
        else:
            last_error = None
            if "Spark Connect Server" in session.runtime_info.endpoints:
                return session
            # NOTE(review): the session state is not inspected here, so a
            # session that can no longer become available is still polled
            # until the timeout - confirm whether a state check is wanted.

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline (and never pass a negative value).
        time.sleep(max(0.0, min(poll_interval, remaining)))

    raise RuntimeError(
        f"Spark Connect endpoint not available for session {session_name} after {timeout} seconds."
    ) from last_error
495+
475496
def _display_session_link_on_creation(self, session_id):
476497
session_url = f"https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}?project={self._project_id}"
477498
plain_message = f"Creating Dataproc Session: {session_url}"
@@ -537,6 +558,9 @@ def _get_exiting_active_session(
537558
)
538559
self._display_view_session_details_button(s8s_session_id)
539560
if session is None:
561+
session_response = self._wait_for_session_available(
562+
session_name
563+
)
540564
session = self.__create_spark_connect_session_from_s8s(
541565
session_response, session_name
542566
)

tests/unit/test_session.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,69 @@ def test_execution_progress_handler(
23932393
)
23942394
self.stopSession(mock_session_controller_client_instance, session)
23952395

2396+
@mock.patch("time.sleep", return_value=None)
@mock.patch("google.cloud.dataproc_v1.SessionControllerClient")
def test_wait_for_session_available_success(
    self, mock_session_controller_client, mock_sleep
):
    """Polling continues until the endpoint is listed, then that session is returned."""
    name = "projects/test-project/locations/test-region/sessions/test-session"
    client = mock_session_controller_client.return_value

    # First two responses: a session with no Spark Connect endpoint yet.
    not_ready = Session()
    not_ready.name = name

    # Third response: the same session, now exposing the endpoint.
    ready = Session()
    ready.name = name
    ready.runtime_info.endpoints["Spark Connect Server"] = (
        "sc://example.com:443"
    )

    client.get_session.side_effect = [not_ready, not_ready, ready]

    # Hand the mocked client to the builder under test.
    builder = DataprocSparkSession.Builder()
    builder._session_controller_client = client

    returned = builder._wait_for_session_available(name, timeout=10)

    # Three fetches, with a sleep after each of the two unsuccessful ones.
    self.assertEqual(returned, ready)
    self.assertEqual(client.get_session.call_count, 3)
    self.assertEqual(mock_sleep.call_count, 2)
2431+
2432+
@mock.patch("time.sleep", return_value=None)
@mock.patch("google.cloud.dataproc_v1.SessionControllerClient")
def test_wait_for_session_available_timeout(
    self, mock_session_controller_client, mock_sleep
):
    """A session that never lists the endpoint makes the wait raise RuntimeError."""
    name = "projects/test-project/locations/test-region/sessions/test-session"
    client = mock_session_controller_client.return_value

    # Every fetch returns a session without the Spark Connect endpoint.
    never_ready = Session()
    never_ready.name = name
    client.get_session.return_value = never_ready

    # Hand the mocked client to the builder under test.
    builder = DataprocSparkSession.Builder()
    builder._session_controller_client = client

    expected_fragment = (
        f"Spark Connect endpoint not available for session {name}"
    )

    # A one-second timeout keeps the test short.
    with self.assertRaises(RuntimeError) as raised:
        builder._wait_for_session_available(name, timeout=1)

    self.assertIn(expected_fragment, str(raised.exception))
2458+
23962459

23972460
class SessionIdValidationTests(unittest.TestCase):
23982461
"""Test cases for session ID validation and custom session ID functionality."""

0 commit comments

Comments
 (0)