Skip to content

Commit 8da16ed

Browse files
authored
fix: Add validation when location or project ID are not provided (#168)
1 parent 7af8498 commit 8da16ed

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

google/cloud/dataproc_spark_connect/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,16 @@ def getOrCreate(self) -> "DataprocSparkSession":
592592
session = PySparkSQLSession.builder.getOrCreate()
593593
return session # type: ignore
594594

595+
if self._project_id is None:
596+
raise DataprocSparkConnectException(
597+
f"Error while creating Dataproc Session: project ID is not set"
598+
)
599+
600+
if self._region is None:
601+
raise DataprocSparkConnectException(
602+
f"Error while creating Dataproc Session: location is not set"
603+
)
604+
595605
# Handle custom session ID by setting it early and letting existing logic handle it
596606
if self._custom_session_id:
597607
self._handle_custom_session_id()

tests/unit/test_session.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,9 +1471,31 @@ def test_create_session_with_valid_notebook_id(
14711471
)
14721472
self.stopSession(mock_session_controller_client_instance, session)
14731473

1474+
def test_create_session_without_project_id(self):
1475+
"""Tests that an exception is raised when project ID is not provided."""
1476+
os.environ.clear()
1477+
try:
1478+
DataprocSparkSession.builder.location("test-region").getOrCreate()
1479+
except DataprocSparkConnectException as e:
1480+
self.assertIn("project ID is not set", str(e))
1481+
1482+
def test_create_session_without_location(self):
1483+
"""Tests that an exception is raised when location is not provided."""
1484+
os.environ.clear()
1485+
try:
1486+
DataprocSparkSession.builder.projectId("test-project").getOrCreate()
1487+
except DataprocSparkConnectException as e:
1488+
self.assertIn("location is not set", str(e))
1489+
14741490

14751491
class DataprocSparkConnectClientTest(unittest.TestCase):
14761492

1493+
def setUp(self):
1494+
self.original_environment = dict(os.environ)
1495+
os.environ.clear()
1496+
os.environ["GOOGLE_CLOUD_PROJECT"] = "test-project"
1497+
os.environ["GOOGLE_CLOUD_REGION"] = "test-region"
1498+
14771499
@staticmethod
14781500
def stopSession(mock_session_controller_client_instance, session):
14791501
session_response = Session()

0 commit comments

Comments
 (0)