     TerminateSessionRequest,
 )
 from pyspark.errors.exceptions import connect as connect_exceptions
-from pyspark.sql import functions as F
 from pyspark.sql.types import StringType


@@ -49,9 +48,28 @@ def test_project():
     return os.getenv("GOOGLE_CLOUD_PROJECT")


+def is_ci_environment():
+    """Detect if running in a CI environment."""
+    return os.getenv("CI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"
+
+
 @pytest.fixture
 def auth_type(request):
-    return getattr(request, "param", "SERVICE_ACCOUNT")
+    """Auto-detect the authentication type based on the environment.
+
+    CI environment (CI=true or GITHUB_ACTIONS=true): Uses SERVICE_ACCOUNT
+    Local environment: Uses END_USER_CREDENTIALS
+    Test parametrization can still override this default.
+    """
+    # Allow test parametrization to override
+    if hasattr(request, "param"):
+        return request.param
+
+    # Auto-detect based on the environment
+    if is_ci_environment():
+        return "SERVICE_ACCOUNT"
+    else:
+        return "END_USER_CREDENTIALS"


 @pytest.fixture
@@ -113,23 +131,29 @@ def session_template_controller_client(test_client_options):

 @pytest.fixture
 def connect_session(test_project, test_region, os_environment):
-    return (
+    session = (
         DataprocSparkSession.builder.projectId(test_project)
         .location(test_region)
         .getOrCreate()
     )
+    yield session
+    # Clean up the session after each test to prevent resource conflicts
+    try:
+        session.stop()
+    except Exception:
+        # Ignore cleanup errors to avoid masking the actual test failure
+        pass


 @pytest.fixture
 def session_name(test_project, test_region, connect_session):
     return f"projects/{test_project}/locations/{test_region}/sessions/{DataprocSparkSession._active_s8s_session_id}"


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_create_spark_session_with_default_notebook_behavior(
     auth_type, connect_session, session_name, session_controller_client
 ):
-    """Test creating a Spark session with default notebook behavior using end user credentials."""
+    """Test creating a Spark session with default notebook behavior using auto-detected authentication."""
     get_session_request = GetSessionRequest()
     get_session_request.name = session_name
     session = session_controller_client.get_session(get_session_request)
@@ -352,7 +376,10 @@ def test_create_spark_session_with_session_template_and_user_provided_dataproc_c
     reason="Skipping PyPI package installation test since it's not supported yet"
 )
 def test_add_artifacts_pypi_package():
-    """Test adding PyPI packages as artifacts to a Spark session."""
+    """Test adding PyPI packages as artifacts to a Spark session.
+
+    Note: Skipped in CI due to infrastructure issues with PyPI package installation.
+    """
     connect_session = DataprocSparkSession.builder.getOrCreate()
     from pyspark.sql.connect.functions import udf, sum
     from pyspark.sql.types import IntegerType
@@ -380,35 +407,38 @@ def generate_random2(row) -> int:

 def test_sql_functions(connect_session):
     """Test basic SQL functions like col(), sum(), count(), etc."""
+    # Import SparkConnect-compatible functions
+    from pyspark.sql.connect.functions import col, sum, count
+
     # Create a test DataFrame
     df = connect_session.createDataFrame(
         [(1, "Alice", 100), (2, "Bob", 200), (3, "Charlie", 150)],
         ["id", "name", "amount"],
     )

     # Test col() function
-    result_col = df.select(F.col("name")).collect()
+    result_col = df.select(col("name")).collect()
     assert len(result_col) == 3
     assert result_col[0]["name"] == "Alice"

     # Test aggregation functions
-    sum_result = df.select(F.sum("amount")).collect()[0][0]
+    sum_result = df.select(sum("amount")).collect()[0][0]
     assert sum_result == 450

-    count_result = df.select(F.count("id")).collect()[0][0]
+    count_result = df.select(count("id")).collect()[0][0]
     assert count_result == 3

     # Test with where clause using col()
-    filtered_df = df.where(F.col("amount") > 150)
+    filtered_df = df.where(col("amount") > 150)
     filtered_count = filtered_df.count()
     assert filtered_count == 1

     # Test multiple column operations
     df_with_calc = df.select(
-        F.col("id"),
-        F.col("name"),
-        F.col("amount"),
-        (F.col("amount") * 0.1).alias("tax"),
+        col("id"),
+        col("name"),
+        col("amount"),
+        (col("amount") * 0.1).alias("tax"),
     )
     tax_results = df_with_calc.collect()
     assert tax_results[0]["tax"] == 10.0
@@ -418,6 +448,9 @@ def test_sql_functions(connect_session):

 def test_sql_udf(connect_session):
     """Test SQL UDF registration and usage."""
+    # Import SparkConnect-compatible functions
+    from pyspark.sql.connect.functions import col, udf
+
     # Create a test DataFrame
     df = connect_session.createDataFrame(
         [(1, "hello"), (2, "world"), (3, "spark")], ["id", "text"]
@@ -431,9 +464,9 @@ def uppercase_func(text):
         return text.upper() if text else None

     # Test UDF with DataFrame API
-    uppercase_udf = F.udf(uppercase_func, StringType())
+    uppercase_udf = udf(uppercase_func, StringType())
     df_with_udf = df.select(
-        "id", "text", uppercase_udf(F.col("text")).alias("upper_text")
+        "id", "text", uppercase_udf(col("text")).alias("upper_text")
     )
     df_result = df_with_udf.collect()

@@ -444,7 +477,6 @@ def uppercase_func(text):
     connect_session.sql("DROP VIEW IF EXISTS test_table")


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_session_reuse_with_custom_id(
     auth_type,
     test_project,
@@ -453,7 +485,8 @@ def test_session_reuse_with_custom_id(
     os_environment,
 ):
     """Test the real-world session reuse scenario: create → terminate → recreate with same ID."""
-    custom_session_id = "ml-pipeline-session"
+    # Use a randomized session ID to avoid conflicts between test runs
+    custom_session_id = f"ml-pipeline-session-{uuid.uuid4().hex[:8]}"

     # Stop any existing session first to ensure clean state
     if DataprocSparkSession._active_s8s_session_id:
@@ -465,9 +498,12 @@ def test_session_reuse_with_custom_id(
             pass

     # PHASE 1: Create initial session with custom ID
-    spark1 = DataprocSparkSession.builder.dataprocSessionId(
-        custom_session_id
-    ).getOrCreate()
+    spark1 = (
+        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
+        .projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )

     # Verify session is created with custom ID
     assert DataprocSparkSession._active_s8s_session_id == custom_session_id
@@ -482,9 +518,12 @@ def test_session_reuse_with_custom_id(
     # Clear cache to force session lookup
     DataprocSparkSession._default_session = None

-    spark2 = DataprocSparkSession.builder.dataprocSessionId(
-        custom_session_id
-    ).getOrCreate()
+    spark2 = (
+        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
+        .projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )

     # Should reuse the same active session
     assert DataprocSparkSession._active_s8s_session_id == custom_session_id
@@ -495,7 +534,7 @@ def test_session_reuse_with_custom_id(
     result2 = df2.count()
     assert result2 == 1

-    # PHASE 3: Terminate session explicitly
+    # PHASE 3: stop() should not terminate the named session
     spark2.stop()

     # PHASE 4: Recreate with same ID - this tests the cleanup and recreation logic
@@ -504,16 +543,19 @@ def test_session_reuse_with_custom_id(
     DataprocSparkSession._active_s8s_session_id = None
     DataprocSparkSession._active_s8s_session_uuid = None

-    spark3 = DataprocSparkSession.builder.dataprocSessionId(
-        custom_session_id
-    ).getOrCreate()
+    spark3 = (
+        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
+        .projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )

-    # Should be a new session with same ID but different UUID
+    # Should be the same session with the same ID
     assert DataprocSparkSession._active_s8s_session_id == custom_session_id
     third_session_uuid = spark3._active_s8s_session_uuid

-    # Should be different UUID (new session instance)
-    assert third_session_uuid != first_session_uuid
+    # Should be the same UUID
+    assert third_session_uuid == first_session_uuid

     # Test functionality on recreated session
     df3 = spark3.createDataFrame([(3, "recreated")], ["id", "stage"])
@@ -546,7 +588,6 @@ def test_session_id_validation_in_integration(
     assert builder._custom_session_id == valid_id


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_sparksql_magic_library_available(connect_session):
     """Test that sparksql-magic library can be imported and loaded."""
     pytest.importorskip(
@@ -580,7 +621,6 @@ def test_sparksql_magic_library_available(connect_session):
     assert data[0]["test_column"] == "integration_test"


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_sparksql_magic_with_dataproc_session(connect_session):
     """Test that sparksql-magic works with registered DataprocSparkSession."""
     pytest.importorskip(