
Commit 3cb5ec8

Added partial integration tests for pre-submit (#148)
1 parent 436d97e commit 3cb5ec8

3 files changed: 159 additions, 33 deletions
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Integration Tests with Service Account Authentication
+#
+# Required GitHub Secrets:
+# - GCP_SA_KEY: Service account JSON key (project_id and client_email extracted automatically)
+# - GCP_REGION: Google Cloud Region (optional, defaults to us-central1)
+# - GCP_SUBNET: Dataproc subnet URI
+#
+# See INTEGRATION_TESTS.md for setup instructions.
+
+name: Integration Tests
+on:
+  pull_request:
+    branches: [ main ]
+  workflow_dispatch:
+
+jobs:
+  integration-test:
+    name: Run integration tests
+    runs-on: ubuntu-latest
+
+    # Only run integration tests if secrets are available
+    if: github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository)
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Cache pip dependencies
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-integration-${{ hashFiles('requirements-dev.txt', 'requirements-test.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-integration-
+            ${{ runner.os }}-pip-
+
+      - name: Install dependencies
+        run: |
+          pip install -r requirements-dev.txt
+          pip install -r requirements-test.txt
+
+      - name: Authenticate to Google Cloud
+        uses: google-github-actions/auth@v2
+        with:
+          credentials_json: ${{ secrets.GCP_SA_KEY }}
+
+      - name: Set up Cloud SDK
+        uses: google-github-actions/setup-gcloud@v2
+
+      - name: Run integration tests
+        env:
+          CI: "true"
+          # Extract from service account JSON automatically
+          GOOGLE_CLOUD_PROJECT: ${{ fromJson(secrets.GCP_SA_KEY).project_id }}
+          DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT: ${{ fromJson(secrets.GCP_SA_KEY).client_email }}
+          # Infrastructure-specific secrets
+          GOOGLE_CLOUD_REGION: ${{ secrets.GCP_REGION || 'us-central1' }}
+          DATAPROC_SPARK_CONNECT_SUBNET: ${{ secrets.GCP_SUBNET }}
+          DATAPROC_SPARK_CONNECT_AUTH_TYPE: "SERVICE_ACCOUNT"
+        run: |
+          python -m pytest tests/integration/ -v --tb=short -x
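
For local runs outside GitHub Actions, the two values the workflow extracts from GCP_SA_KEY via fromJson() (project_id and client_email) live in a downloaded service-account key file. A minimal local sketch, assuming a hypothetical key path and mirroring the environment the workflow sets; the subnet URI is not part of this commit and must come from your own infrastructure:

# local_integration_env.py - illustrative helper, not part of this commit
import json
import os
import subprocess

# Assumes a service-account key downloaded to this (hypothetical) path.
key_path = os.path.expanduser("~/keys/dataproc-it-sa.json")
with open(key_path) as f:
    sa = json.load(f)

env = dict(os.environ)
env.update({
    "CI": "true",
    "GOOGLE_CLOUD_PROJECT": sa["project_id"],
    "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT": sa["client_email"],
    "GOOGLE_CLOUD_REGION": os.getenv("GCP_REGION", "us-central1"),
    # Supply your own Dataproc subnet URI, as the workflow does via GCP_SUBNET.
    "DATAPROC_SPARK_CONNECT_SUBNET": os.environ.get("GCP_SUBNET", ""),
    "DATAPROC_SPARK_CONNECT_AUTH_TYPE": "SERVICE_ACCOUNT",
})

# Same command the workflow runs.
subprocess.run(
    ["python", "-m", "pytest", "tests/integration/", "-v", "--tb=short", "-x"],
    env=env,
    check=True,
)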

pyproject.toml

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
 [tool.pyink]
 line-length = 80 # (default is 88)
 pyink-indentation = 4 # (default is 4)
+
+[tool.pytest.ini_options]
+markers = [
+    "integration: marks tests as integration tests",
+    "ci_safe: marks tests that work in CI environment",
+]
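
The markers registered above silence pytest's unknown-mark warnings and make subsets of the suite selectable. A hypothetical test showing how they might be applied and filtered (the test name and body are illustrative only, not part of this commit):

import pytest


@pytest.mark.integration
@pytest.mark.ci_safe
def test_marker_usage_example():
    # Illustrative only: selected in CI with
    #   python -m pytest -m "integration and ci_safe"
    assert True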

tests/integration/test_session.py

Lines changed: 73 additions & 33 deletions
@@ -32,7 +32,6 @@
     TerminateSessionRequest,
 )
 from pyspark.errors.exceptions import connect as connect_exceptions
-from pyspark.sql import functions as F
 from pyspark.sql.types import StringType


@@ -49,9 +48,28 @@ def test_project():
     return os.getenv("GOOGLE_CLOUD_PROJECT")


+def is_ci_environment():
+    """Detect if running in CI environment."""
+    return os.getenv("CI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"
+
+
 @pytest.fixture
 def auth_type(request):
-    return getattr(request, "param", "SERVICE_ACCOUNT")
+    """Auto-detect authentication type based on environment.
+
+    CI environment (CI=true or GITHUB_ACTIONS=true): Uses SERVICE_ACCOUNT
+    Local environment: Uses END_USER_CREDENTIALS
+    Test parametrization can still override this default.
+    """
+    # Allow test parametrization to override
+    if hasattr(request, "param"):
+        return request.param
+
+    # Auto-detect based on environment
+    if is_ci_environment():
+        return "SERVICE_ACCOUNT"
+    else:
+        return "END_USER_CREDENTIALS"


 @pytest.fixture
@@ -113,23 +131,29 @@ def session_template_controller_client(test_client_options):

 @pytest.fixture
 def connect_session(test_project, test_region, os_environment):
-    return (
+    session = (
         DataprocSparkSession.builder.projectId(test_project)
         .location(test_region)
         .getOrCreate()
     )
+    yield session
+    # Clean up the session after each test to prevent resource conflicts
+    try:
+        session.stop()
+    except Exception:
+        # Ignore cleanup errors to avoid masking the actual test failure
+        pass


 @pytest.fixture
 def session_name(test_project, test_region, connect_session):
     return f"projects/{test_project}/locations/{test_region}/sessions/{DataprocSparkSession._active_s8s_session_id}"


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_create_spark_session_with_default_notebook_behavior(
     auth_type, connect_session, session_name, session_controller_client
 ):
-    """Test creating a Spark session with default notebook behavior using end user credentials."""
+    """Test creating a Spark session with default notebook behavior using auto-detected authentication."""
     get_session_request = GetSessionRequest()
     get_session_request.name = session_name
     session = session_controller_client.get_session(get_session_request)
@@ -352,7 +376,10 @@ def test_create_spark_session_with_session_template_and_user_provided_dataproc_c
     reason="Skipping PyPI package installation test since it's not supported yet"
 )
 def test_add_artifacts_pypi_package():
-    """Test adding PyPI packages as artifacts to a Spark session."""
+    """Test adding PyPI packages as artifacts to a Spark session.
+
+    Note: Skipped in CI due to infrastructure issues with PyPI package installation.
+    """
     connect_session = DataprocSparkSession.builder.getOrCreate()
     from pyspark.sql.connect.functions import udf, sum
     from pyspark.sql.types import IntegerType
@@ -380,35 +407,38 @@ def generate_random2(row) -> int:

 def test_sql_functions(connect_session):
     """Test basic SQL functions like col(), sum(), count(), etc."""
+    # Import SparkConnect-compatible functions
+    from pyspark.sql.connect.functions import col, sum, count
+
     # Create a test DataFrame
     df = connect_session.createDataFrame(
         [(1, "Alice", 100), (2, "Bob", 200), (3, "Charlie", 150)],
         ["id", "name", "amount"],
     )

     # Test col() function
-    result_col = df.select(F.col("name")).collect()
+    result_col = df.select(col("name")).collect()
     assert len(result_col) == 3
     assert result_col[0]["name"] == "Alice"

     # Test aggregation functions
-    sum_result = df.select(F.sum("amount")).collect()[0][0]
+    sum_result = df.select(sum("amount")).collect()[0][0]
     assert sum_result == 450

-    count_result = df.select(F.count("id")).collect()[0][0]
+    count_result = df.select(count("id")).collect()[0][0]
     assert count_result == 3

     # Test with where clause using col()
-    filtered_df = df.where(F.col("amount") > 150)
+    filtered_df = df.where(col("amount") > 150)
     filtered_count = filtered_df.count()
     assert filtered_count == 1

     # Test multiple column operations
     df_with_calc = df.select(
-        F.col("id"),
-        F.col("name"),
-        F.col("amount"),
-        (F.col("amount") * 0.1).alias("tax"),
+        col("id"),
+        col("name"),
+        col("amount"),
+        (col("amount") * 0.1).alias("tax"),
     )
     tax_results = df_with_calc.collect()
     assert tax_results[0]["tax"] == 10.0
@@ -418,6 +448,9 @@ def test_sql_functions(connect_session):

 def test_sql_udf(connect_session):
     """Test SQL UDF registration and usage."""
+    # Import SparkConnect-compatible functions
+    from pyspark.sql.connect.functions import col, udf
+
     # Create a test DataFrame
     df = connect_session.createDataFrame(
         [(1, "hello"), (2, "world"), (3, "spark")], ["id", "text"]
@@ -431,9 +464,9 @@ def uppercase_func(text):
         return text.upper() if text else None

     # Test UDF with DataFrame API
-    uppercase_udf = F.udf(uppercase_func, StringType())
+    uppercase_udf = udf(uppercase_func, StringType())
     df_with_udf = df.select(
-        "id", "text", uppercase_udf(F.col("text")).alias("upper_text")
+        "id", "text", uppercase_udf(col("text")).alias("upper_text")
     )
     df_result = df_with_udf.collect()

@@ -444,7 +477,6 @@ def uppercase_func(text):
     connect_session.sql("DROP VIEW IF EXISTS test_table")


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_session_reuse_with_custom_id(
     auth_type,
     test_project,
@@ -453,7 +485,8 @@ def test_session_reuse_with_custom_id(
     os_environment,
 ):
     """Test the real-world session reuse scenario: create → terminate → recreate with same ID."""
-    custom_session_id = "ml-pipeline-session"
+    # Use a randomized session ID to avoid conflicts between test runs
+    custom_session_id = f"ml-pipeline-session-{uuid.uuid4().hex[:8]}"

     # Stop any existing session first to ensure clean state
     if DataprocSparkSession._active_s8s_session_id:
@@ -465,9 +498,12 @@ def test_session_reuse_with_custom_id(
             pass

     # PHASE 1: Create initial session with custom ID
-    spark1 = DataprocSparkSession.builder.dataprocSessionId(
-        custom_session_id
-    ).getOrCreate()
+    spark1 = (
+        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
+        .projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )

     # Verify session is created with custom ID
     assert DataprocSparkSession._active_s8s_session_id == custom_session_id
@@ -482,9 +518,12 @@ def test_session_reuse_with_custom_id(
     # Clear cache to force session lookup
     DataprocSparkSession._default_session = None

-    spark2 = DataprocSparkSession.builder.dataprocSessionId(
-        custom_session_id
-    ).getOrCreate()
+    spark2 = (
+        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
+        .projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )

     # Should reuse the same active session
     assert DataprocSparkSession._active_s8s_session_id == custom_session_id
@@ -495,7 +534,7 @@ def test_session_reuse_with_custom_id(
     result2 = df2.count()
     assert result2 == 1

-    # PHASE 3: Terminate session explicitly
+    # PHASE 3: Stop should not terminate named session
     spark2.stop()

     # PHASE 4: Recreate with same ID - this tests the cleanup and recreation logic
@@ -504,16 +543,19 @@ def test_session_reuse_with_custom_id(
     DataprocSparkSession._active_s8s_session_id = None
     DataprocSparkSession._active_s8s_session_uuid = None

-    spark3 = DataprocSparkSession.builder.dataprocSessionId(
-        custom_session_id
-    ).getOrCreate()
+    spark3 = (
+        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
+        .projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )

-    # Should be a new session with same ID but different UUID
+    # Should be the same session with the same ID
     assert DataprocSparkSession._active_s8s_session_id == custom_session_id
     third_session_uuid = spark3._active_s8s_session_uuid

-    # Should be different UUID (new session instance)
-    assert third_session_uuid != first_session_uuid
+    # Should be the same UUID
+    assert third_session_uuid == first_session_uuid

     # Test functionality on recreated session
     df3 = spark3.createDataFrame([(3, "recreated")], ["id", "stage"])
@@ -546,7 +588,6 @@ def test_session_id_validation_in_integration(
         assert builder._custom_session_id == valid_id


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_sparksql_magic_library_available(connect_session):
     """Test that sparksql-magic library can be imported and loaded."""
     pytest.importorskip(
@@ -580,7 +621,6 @@ def test_sparksql_magic_library_available(connect_session):
     assert data[0]["test_column"] == "integration_test"


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_sparksql_magic_with_dataproc_session(connect_session):
     """Test that sparksql-magic works with registered DataprocSparkSession."""
     pytest.importorskip(
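
The auto-detecting auth_type fixture added in this diff still honors indirect parametrization, which is the same mechanism the removed decorators used to pin END_USER_CREDENTIALS. A minimal sketch of how a test could opt back into a specific credential type (the test name is illustrative, not part of this commit):

import pytest


@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
def test_with_pinned_auth_type(auth_type, connect_session):
    # When parametrized, the fixture returns request.param and bypasses CI detection.
    assert auth_type == "END_USER_CREDENTIALS"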

0 commit comments
