diff --git a/google/cloud/dataproc_spark_connect/session.py b/google/cloud/dataproc_spark_connect/session.py index d1e3570..2fd2211 100644 --- a/google/cloud/dataproc_spark_connect/session.py +++ b/google/cloud/dataproc_spark_connect/session.py @@ -592,6 +592,16 @@ def getOrCreate(self) -> "DataprocSparkSession": session = PySparkSQLSession.builder.getOrCreate() return session # type: ignore + if self._project_id is None: + raise DataprocSparkConnectException( + f"Error while creating Dataproc Session: project ID is not set" + ) + + if self._region is None: + raise DataprocSparkConnectException( + f"Error while creating Dataproc Session: location is not set" + ) + # Handle custom session ID by setting it early and letting existing logic handle it if self._custom_session_id: self._handle_custom_session_id() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d75c8ce..703276f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1471,9 +1471,31 @@ def test_create_session_with_valid_notebook_id( ) self.stopSession(mock_session_controller_client_instance, session) + def test_create_session_without_project_id(self): + """Tests that an exception is raised when project ID is not provided.""" + os.environ.clear() + try: + DataprocSparkSession.builder.location("test-region").getOrCreate() + except DataprocSparkConnectException as e: + self.assertIn("project ID is not set", str(e)) + + def test_create_session_without_location(self): + """Tests that an exception is raised when location is not provided.""" + os.environ.clear() + try: + DataprocSparkSession.builder.projectId("test-project").getOrCreate() + except DataprocSparkConnectException as e: + self.assertIn("location is not set", str(e)) + class DataprocSparkConnectClientTest(unittest.TestCase): + def setUp(self): + self.original_environment = dict(os.environ) + os.environ.clear() + os.environ["GOOGLE_CLOUD_PROJECT"] = "test-project" + os.environ["GOOGLE_CLOUD_REGION"] = "test-region" + @staticmethod def stopSession(mock_session_controller_client_instance, session): session_response = Session()