@@ -2393,6 +2393,69 @@ def test_execution_progress_handler(
23932393 )
23942394 self .stopSession (mock_session_controller_client_instance , session )
23952395
2396+ @mock .patch ("time.sleep" , return_value = None )
2397+ @mock .patch ("google.cloud.dataproc_v1.SessionControllerClient" )
2398+ def test_wait_for_session_available_success (
2399+ self , mock_session_controller_client , mock_sleep
2400+ ):
2401+ """Test that the method waits and returns the session when the endpoint appears."""
2402+ mock_client = mock_session_controller_client .return_value
2403+ session_name = "projects/test-project/locations/test-region/sessions/test-session"
2404+
2405+ # Session without the endpoint
2406+ session_pending = Session ()
2407+ session_pending .name = session_name
2408+
2409+ # Session with the endpoint
2410+ session_ready = Session ()
2411+ session_ready .name = session_name
2412+ session_ready .runtime_info .endpoints [
2413+ "Spark Connect Server"
2414+ ] = "sc://example.com:443"
2415+
2416+ # Mock get_session to return pending, then ready
2417+ mock_client .get_session .side_effect = [
2418+ session_pending ,
2419+ session_pending ,
2420+ session_ready ,
2421+ ]
2422+
2423+ builder = DataprocSparkSession .Builder ()
2424+ builder ._session_controller_client = mock_client # Inject the mock client
2425+
2426+ result = builder ._wait_for_session_available (session_name , timeout = 10 )
2427+
2428+ self .assertEqual (result , session_ready )
2429+ self .assertEqual (mock_client .get_session .call_count , 3 )
2430+ self .assertEqual (mock_sleep .call_count , 2 )
2431+
2432+ @mock .patch ("time.sleep" , return_value = None )
2433+ @mock .patch ("google.cloud.dataproc_v1.SessionControllerClient" )
2434+ def test_wait_for_session_available_timeout (
2435+ self , mock_session_controller_client , mock_sleep
2436+ ):
2437+ """Test that the method raises RuntimeError on timeout."""
2438+ mock_client = mock_session_controller_client .return_value
2439+ session_name = "projects/test-project/locations/test-region/sessions/test-session"
2440+
2441+ # Session that never gets the endpoint
2442+ session_pending = Session ()
2443+ session_pending .name = session_name
2444+
2445+ mock_client .get_session .return_value = session_pending
2446+
2447+ builder = DataprocSparkSession .Builder ()
2448+ builder ._session_controller_client = mock_client # Inject the mock client
2449+
2450+ with self .assertRaises (RuntimeError ) as context :
2451+ # Use a short timeout for the test
2452+ builder ._wait_for_session_available (session_name , timeout = 1 )
2453+
2454+ self .assertIn (
2455+ f"Spark Connect endpoint not available for session { session_name } " ,
2456+ str (context .exception ),
2457+ )
2458+
23962459
23972460class SessionIdValidationTests (unittest .TestCase ):
23982461 """Test cases for session ID validation and custom session ID functionality."""
0 commit comments