diff --git a/google/cloud/dataproc_magics/__init__.py b/google/cloud/dataproc_magics/__init__.py
new file mode 100644
index 0000000..b695f33
--- /dev/null
+++ b/google/cloud/dataproc_magics/__init__.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "0.1.0"
+
+
+from google.cloud import storage
+from ._internal import magic
+
+
+_original_pip = None
+
+
+def load_ipython_extension(ipython):
+    """Called by IPython when this module is loaded as an IPython ext."""
+    global _original_pip
+    _original_pip = ipython.find_magic("pip")
+
+    if _original_pip:
+        magics = magic.DataprocMagics(
+            shell=ipython,
+            original_pip=_original_pip,
+            gcs_client=storage.Client(),
+        )
+        ipython.register_magics(magics)
+
+
+def unload_ipython_extension(ipython):
+    """Called by IPython when this module is unloaded as an IPython ext."""
+    global _original_pip
+    if _original_pip:
+        ipython.register_magic_function(
+            _original_pip, magic_kind="line", magic_name="pip"
+        )
+        _original_pip = None
+
+
+__all__ = [
+    "__version__",
+    "load_ipython_extension",
+    "unload_ipython_extension",
+]
diff --git a/google/cloud/dataproc_magics/_internal/__init__.py b/google/cloud/dataproc_magics/_internal/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/google/cloud/dataproc_magics/_internal/dl.py b/google/cloud/dataproc_magics/_internal/dl.py
new file mode 100644
index 0000000..42fca40
--- /dev/null
+++ b/google/cloud/dataproc_magics/_internal/dl.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for downloading files from GCS."""
+
+import os
+import shutil
+import tempfile
+
+from google.cloud import storage
+
+
+class GcsDownloader:
+    """Helper for downloading files from GCS.
+
+    An instance is a context manager that downloads files into a per-instance
+    temporary directory created under the configured base tmpdir.
+    """
+
+    def __init__(self, client: storage.Client, tmpdir: str | None):
+        self._client = client
+        self._base_tmpdir = tmpdir
+        # Per-context tmpdir inside base.
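
A quick usage sketch for context (bucket and file names here are illustrative,
not from the tests):

    %load_ext google.cloud.dataproc_magics
    %pip install gs://my-bucket/my-package-0.1.0.whl
    %pip install -r gs://my-bucket/requirements.txt

Once loaded, the overriding %pip magic (magic.py below) downloads any gs://
arguments to a local temporary directory and then delegates to IPython's
original pip magic.
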
+        self._tmpdir: str | None = None
+
+    def __enter__(self):
+        if self._tmpdir is not None:
+            raise RuntimeError(f"{type(self)} has already been entered")
+        self._tmpdir = tempfile.mkdtemp(dir=self._base_tmpdir)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._tmpdir is None:
+            raise RuntimeError(f"{type(self)} has not been entered")
+        print(f"Removing GCS temporary download directory {self._tmpdir}")
+        try:
+            shutil.rmtree(self._tmpdir)
+        except OSError as e:
+            print(
+                f"Warning: Failed to remove temporary directory {self._tmpdir}: {e}"
+            )
+        self._tmpdir = None
+
+    def download(self, url: str):
+        """Download the given GCS URL to a temporary file."""
+        if self._tmpdir is None:
+            raise RuntimeError("Cannot download outside of a 'with' block")
+        blob = storage.Blob.from_string(url, self._client)
+        if blob.name is None:
+            raise ValueError(f"Couldn't parse blob from URL: {url}")
+        blob_name = blob.name.rsplit("/", 1)[-1]
+        tmpfile = os.path.join(self._tmpdir, blob_name)
+        print(f"Downloading {url} to {tmpfile}")
+        blob.download_to_filename(tmpfile)
+        return tmpfile
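
A minimal sketch of using this helper on its own (bucket name illustrative):

    with GcsDownloader(storage.Client(), None) as downloader:
        local_path = downloader.download("gs://my-bucket/requirements.txt")
        # local_path sits in a fresh per-context directory that is removed
        # when the with block exits.

Passing None as tmpdir puts the per-context directory under the system
temporary directory.
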
diff --git a/google/cloud/dataproc_magics/_internal/magic.py b/google/cloud/dataproc_magics/_internal/magic.py
new file mode 100644
index 0000000..64e6525
--- /dev/null
+++ b/google/cloud/dataproc_magics/_internal/magic.py
@@ -0,0 +1,69 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataproc magic implementations."""
+
+from collections.abc import Callable
+import shlex
+
+from google.cloud import storage
+from IPython.core import magic
+import traitlets
+
+from . import dl
+
+
+@magic.magics_class
+class DataprocMagics(magic.Magics):
+    """Dataproc magics class."""
+
+    tmpdir = traitlets.Unicode(
+        default_value=None,
+        allow_none=True,
+        help="Temporary directory for downloads; defaults to system temp dir",
+    ).tag(config=True)
+
+    def __init__(
+        self,
+        shell,
+        original_pip: Callable[[str], None],
+        gcs_client: storage.Client,
+        **kwargs,
+    ):
+        super().__init__(shell, **kwargs)
+        self._original_pip = original_pip
+        self._gcs_client = gcs_client
+
+    def _transform_line(self, line: str, downloader: dl.GcsDownloader) -> str:
+        new_args = []
+        for arg in shlex.split(line):
+            gcs_url_start = arg.find("gs://")
+            # gs:// found either at the beginning of an arg, or anywhere in an
+            # option/value starting with - (short or long form).
+            if gcs_url_start != -1 and (arg[0] == "-" or gcs_url_start == 0):
+                prefix = arg[:gcs_url_start]
+                url = arg[gcs_url_start:]
+                new_args.append(prefix + downloader.download(url))
+            else:
+                new_args.append(arg)
+        return shlex.join(new_args)
+
+    @magic.line_magic
+    def pip(self, line: str) -> None:
+        if "gs://" in line:
+            with dl.GcsDownloader(self._gcs_client, self.tmpdir) as downloader:
+                new_line = self._transform_line(line, downloader)
+                self._original_pip(new_line)
+        else:
+            self._original_pip(line)
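
Concretely, _transform_line rewrites pip arguments as follows (examples taken
from the unit tests; the /tmp paths stand in for the per-context download
directory):

    install gs://my-bucket/pkg1.whl                ->  install /tmp/pkg1.whl
    install -rgs://my-bucket/reqs.txt              ->  install -r/tmp/reqs.txt
    install --requirement=gs://my-bucket/reqs.txt  ->  install --requirement=/tmp/reqs.txt

Arguments without gs://, and non-option arguments where gs:// is not at the
start, pass through unchanged.
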
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5cf7026..5b161af 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,6 @@
 google-api-core>=2.19
 google-cloud-dataproc>=5.18
+google-cloud-storage>=3.7.0
 ipython~=9.1
 ipywidgets>=8.0.0
 packaging>=20.0
@@ -9,3 +10,4 @@ setuptools>=72.0
 sparksql-magic>=0.0.3
 tqdm>=4.67
 websockets>=14.0
+jupyter-kernel-test
diff --git a/requirements-test.txt b/requirements-test.txt
index 3443484..672beb8 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,2 +1,4 @@
 pytest>=8.0
 pytest-xdist>=3.0
+jupyter-client>=8.0
+nbformat>=5.10
\ No newline at end of file
diff --git a/tests/integration/dataproc_magics/__init__.py b/tests/integration/dataproc_magics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/dataproc_magics/test_magics.py b/tests/integration/dataproc_magics/test_magics.py
new file mode 100644
index 0000000..80f7249
--- /dev/null
+++ b/tests/integration/dataproc_magics/test_magics.py
@@ -0,0 +1,152 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+import textwrap
+import subprocess
+import tempfile
+import shutil
+import unittest
+
+from jupyter_kernel_test import KernelTests
+from jupyter_client.kernelspec import KernelSpecManager
+from jupyter_client.manager import KernelManager
+
+from google.cloud import storage
+
+
+class TestDataprocMagics(KernelTests):
+    kernel_name = "python3"  # Will be updated in setUp
+
+    @classmethod
+    def setUpClass(cls):
+        # Override to prevent default kernel from starting.
+        # We start a new kernel for each test method.
+        pass
+
+    @classmethod
+    def tearDownClass(cls):
+        # Override to prevent default kernel from being shut down.
+        pass
+
+    def _get_requirements_file(self):
+        bucket_name = os.environ.get("DATAPROC_TEST_BUCKET")
+        if not bucket_name:
+            self.skipTest("DATAPROC_TEST_BUCKET environment variable not set")
+
+        object_name = "test-magics-requirements.txt"
+        storage_client = storage.Client()
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(object_name)
+
+        # Download and verify content
+        downloaded_content = blob.download_as_text()
+        self.assertEqual(downloaded_content, "humanize==4.14.0\n")
+
+        return bucket_name, object_name
+
+    def setUp(self):
+        self.temp_dir = tempfile.mkdtemp(prefix="dataproc-magics-test-")
+        venv_dir = os.path.join(self.temp_dir, "venv")
+
+        # Create venv
+        subprocess.run(
+            [sys.executable, "-m", "venv", venv_dir],
+            check=True,
+            capture_output=True,
+        )
+
+        # Install deps
+        pip_exe = os.path.join(venv_dir, "bin", "pip")
+        subprocess.run(
+            [pip_exe, "install", "ipykernel", "google-cloud-storage"],
+            check=True,
+            capture_output=True,
+        )
+        subprocess.run(
+            [pip_exe, "install", "-e", "."], check=True, capture_output=True
+        )
+
+        # Install kernelspec
+        python_exe = os.path.join(venv_dir, "bin", "python")
+        self.kernel_name = f"temp-kernel-{os.path.basename(self.temp_dir)}"
+
+        subprocess.run(
+            [
+                python_exe,
+                "-m",
+                "ipykernel",
+                "install",
+                "--name",
+                self.kernel_name,
+                "--prefix",
+                self.temp_dir,
+            ],
+            check=True,
+            capture_output=True,
+        )
+
+        kernel_dir = os.path.join(self.temp_dir, "share", "jupyter", "kernels")
+
+        # Start kernel
+        ksm = KernelSpecManager(kernel_dirs=[kernel_dir])
+        self.km = KernelManager(
+            kernel_spec_manager=ksm, kernel_name=self.kernel_name
+        )
+        self.km.start_kernel()
+
+        self.kc = self.km.client()
+        self.kc.load_connection_file()
+        self.kc.start_channels()
+        self.kc.wait_for_ready()
+
+    def tearDown(self):
+        self.kc.stop_channels()
+        self.km.shutdown_kernel()
+        shutil.rmtree(self.temp_dir)
+
+    def test_pip_install_from_gcs(self):
+        bucket_name, object_name = self._get_requirements_file()
+
+        # Load extension
+        reply, output_msgs = self.execute_helper(
+            "%load_ext google.cloud.dataproc_magics"
+        )
+        # Assert that there are no stream messages (stdout/stderr)
+        self.assertFalse(
+            any(msg["msg_type"] == "stream" for msg in output_msgs)
+        )
+
+        # Pip install
+        install_cmd = f"%pip install -r gs://{bucket_name}/{object_name}"
+        self.assert_in_stdout(
+            install_cmd, "Successfully installed humanize-4.14.0"
+        )
+
+        # Import and use humanize
+        code = textwrap.dedent(
+            """
+            import humanize
+            print(humanize.intcomma(12345))
+            """
+        )
+        # assert_stdout adds a newline to the expected output if it's not present,
+        # because print statements typically add a newline.
+        self.assert_stdout(code, "12,345\n")
+
+
+if __name__ == "__main__":
+    unittest.main()
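
A note on the kernelspec installation in setUp above: "ipykernel install
--prefix" writes share/jupyter/kernels/<name>/kernel.json under the prefix,
with an argv pointing at the venv's interpreter, which is why resolving the
name through KernelSpecManager(kernel_dirs=[...]) starts the kernel inside
the venv rather than in the test runner's environment. A sketch for
inspecting it (assuming setUp has run; not part of the test):

    import json
    spec_path = os.path.join(
        self.temp_dir, "share", "jupyter", "kernels",
        self.kernel_name, "kernel.json",
    )
    with open(spec_path) as f:
        # Expected shape: [<venv python>, "-m", "ipykernel_launcher", ...]
        print(json.load(f)["argv"])
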
diff --git a/tests/integration/dataproc_magics/test_magics_with_ipython.py b/tests/integration/dataproc_magics/test_magics_with_ipython.py
new file mode 100644
index 0000000..9ea39a2
--- /dev/null
+++ b/tests/integration/dataproc_magics/test_magics_with_ipython.py
@@ -0,0 +1,213 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import textwrap
+from dataclasses import dataclass, field
+from queue import Empty
+
+import pytest
+from google.cloud import storage
+from jupyter_client.kernelspec import KernelSpecManager
+from jupyter_client.manager import KernelManager
+
+# Mark all tests in this file as integration tests
+pytestmark = pytest.mark.integration
+
+
+@dataclass
+class ExecutionResult:
+    """A dataclass to hold the results of executing a notebook cell."""
+
+    success: bool
+    stdout: str = ""
+    stderr: str = ""
+    errors: list[str] = field(default_factory=list)
+
+
+class IsolatedIPythonSession:
+    """
+    Manages a fully isolated IPython kernel in a temporary venv.
+    Provides an execute() method to run code and capture output,
+    mimicking a notebook environment.
+    """
+
+    def __init__(self):
+        self._temp_dir = tempfile.mkdtemp(prefix="dataproc-magics-test-")
+        self._venv_dir = os.path.join(self._temp_dir, "venv")
+        self._kernel_manager = None
+        self._kernel_client = None
+
+        self._setup_venv()
+        self._start_kernel()
+
+    def _setup_venv(self):
+        subprocess.run(
+            [sys.executable, "-m", "venv", self._venv_dir],
+            check=True,
+            capture_output=True,
+        )
+        pip_exe = os.path.join(self._venv_dir, "bin", "pip")
+        # jupyter_client is needed to manage the kernel
+        subprocess.run(
+            [
+                pip_exe,
+                "install",
+                "ipykernel",
+                "google-cloud-storage",
+                "jupyter-client",
+            ],
+            check=True,
+            capture_output=True,
+        )
+        subprocess.run(
+            [pip_exe, "install", "-e", "."], check=True, capture_output=True
+        )
+
+    def _start_kernel(self):
+        # KernelManager cannot be pointed at an arbitrary interpreter, so
+        # install a kernelspec for the venv's python into the temp dir (as
+        # test_magics.py does) and resolve it via a dedicated
+        # KernelSpecManager.
+        python_exe = os.path.join(self._venv_dir, "bin", "python")
+        kernel_name = f"temp-kernel-{os.path.basename(self._temp_dir)}"
+        subprocess.run(
+            [python_exe, "-m", "ipykernel", "install",
+             "--name", kernel_name, "--prefix", self._temp_dir],
+            check=True,
+            capture_output=True,
+        )
+        kernel_dir = os.path.join(
+            self._temp_dir, "share", "jupyter", "kernels"
+        )
+        self._kernel_manager = KernelManager(
+            kernel_spec_manager=KernelSpecManager(kernel_dirs=[kernel_dir]),
+            kernel_name=kernel_name,
+        )
+        self._kernel_manager.start_kernel()
+        self._kernel_client = self._kernel_manager.client()
+        self._kernel_client.start_channels()
+        self._kernel_client.wait_for_ready()
+
+        # Load the dataproc magics extension for all tests in this session
+        self.execute("%load_ext google.cloud.dataproc_magics")
+
+    def execute(self, code: str, timeout: int = 60) -> ExecutionResult:
+        """Executes a cell and returns the collected output."""
+        msg_id = self._kernel_client.execute(code)
+        stdout, stderr, errors = [], [], []
+
+        while True:
+            try:
+                msg = self._kernel_client.get_iopub_msg(timeout=timeout)
+                msg_type = msg["header"]["msg_type"]
+                content = msg["content"]
+
+                if (
+                    msg_type == "status"
+                    and content["execution_state"] == "idle"
+                ):
+                    break  # Execution is done
+
+                if msg_type == "stream":
+                    if content["name"] == "stdout":
+                        stdout.append(content["text"])
+                    else:
+                        stderr.append(content["text"])
+                elif msg_type == "error":
+                    errors.append("\n".join(content["traceback"]))
+
+            except Empty:
+                # Timed out waiting for messages.
+                return ExecutionResult(
+                    success=False, stderr="Execution timed out."
+                )
+
+        # Final reply from the shell channel
+        reply = self._kernel_client.get_shell_msg(timeout=timeout)
+        success = reply["content"]["status"] == "ok"
+
+        return ExecutionResult(
+            success=success,
+            stdout="".join(stdout),
+            stderr="".join(stderr),
+            errors=errors,
+        )
+
+    def close(self):
+        """Shutdown kernel and cleanup the venv."""
+        if self._kernel_client:
+            self._kernel_client.stop_channels()
+        if self._kernel_manager and self._kernel_manager.is_alive():
+            self._kernel_manager.shutdown_kernel()
+        if self._temp_dir and os.path.exists(self._temp_dir):
+            shutil.rmtree(self._temp_dir)
+
+
+@pytest.fixture(scope="function")
+def ipython_session():
+    """Fixture to provide a clean, isolated IPython session for a single test."""
+    session = IsolatedIPythonSession()
+    yield session
+    session.close()
+
+
+@pytest.fixture(scope="module")
+def requirements_gcs_path():
+    """Fixture to get the GCS path for a test requirements file."""
+    bucket_name = os.environ.get("DATAPROC_TEST_BUCKET")
+    if not bucket_name:
+        pytest.skip("DATAPROC_TEST_BUCKET environment variable not set")
+    object_name = "test-magics-requirements.txt"
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)
+    blob = bucket.blob(object_name)
+
+    try:
+        content = blob.download_as_text()
+        assert "humanize" in content
+    except Exception as e:
+        pytest.fail(
+            f"Failed to download/verify GCS file: gs://{bucket_name}/{object_name}. Error: {e}"
+        )
+
+    return f"gs://{bucket_name}/{object_name}"
+
+
+def test_pip_install_from_gcs_isolated(ipython_session, requirements_gcs_path):
+    """
+    Tests installing a package from GCS in a fully isolated session,
+    with clean, cell-by-cell execution.
+    """
+    # 1. Run the pip install command using the magic.
+    install_cmd = f"%pip install -r {requirements_gcs_path}"
+    result = ipython_session.execute(install_cmd)
+
+    assert result.success, f"Magic command failed: {result.stderr}"
+    assert "Successfully installed humanize" in result.stdout
+
+    # 2. Verify the installed package can be imported and used.
+    verify_code = textwrap.dedent(
+        """
+        import humanize
+        print(humanize.intcomma(12345))
+        """
+    )
+    result = ipython_session.execute(verify_code)
+    assert result.success, f"Verification code failed: {result.stderr}"
+    assert "12,345" in result.stdout
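
For reference, a minimal sketch of driving the session helper directly,
outside the pytest fixture:

    session = IsolatedIPythonSession()
    try:
        result = session.execute("print('hello')")
        assert result.success and result.stdout == "hello\n"
    finally:
        session.close()
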
+ ) + + # Final reply from the shell channel + reply = self._kernel_client.get_shell_msg(timeout=timeout) + success = reply["content"]["status"] == "ok" + + return ExecutionResult( + success=success, + stdout="".join(stdout), + stderr="".join(stderr), + errors=errors, + ) + + def close(self): + """Shutdown kernel and cleanup the venv.""" + if self._kernel_client: + self._kernel_client.stop_channels() + if self._kernel_manager and self._kernel_manager.is_alive(): + self._kernel_manager.shutdown_kernel() + if self._temp_dir and os.path.exists(self._temp_dir): + shutil.rmtree(self._temp_dir) + + +@pytest.fixture(scope="function") +def ipython_session(): + """Fixture to provide a clean, isolated IPython session for a single test.""" + session = IsolatedIPythonSession() + yield session + session.close() + + +@pytest.fixture(scope="module") +def requirements_gcs_path(): + """Fixture to get the GCS path for a test requirements file.""" + bucket_name = os.environ.get("DATAPROC_TEST_BUCKET") + if not bucket_name: + pytest.skip("DATAPROC_TEST_BUCKET environment variable not set") + # ... (rest of the GCS fixture is the same) + object_name = "test-magics-requirements.txt" + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(object_name) + + try: + content = blob.download_as_text() + assert "humanize" in content + except Exception as e: + pytest.fail( + f"Failed to download/verify GCS file: gs://{bucket_name}/{object_name}. Error: {e}" + ) + + return f"gs://{bucket_name}/{object_name}" + + +def test_pip_install_from_gcs_isolated(ipython_session, requirements_gcs_path): + """ + Tests installing a package from GCS in a fully isolated session, + with clean, cell-by-cell execution. + """ + # 1. Run the pip install command using the magic. + install_cmd = f"%pip install -r {requirements_gcs_path}" + result = ipython_session.execute(install_cmd) + + assert result.success, f"Magic command failed: {result.stderr}" + assert "Successfully installed humanize" in result.stdout + + # 2. Verify the installed package can be imported and used. + verify_code = textwrap.dedent( + """ + import humanize + print(humanize.intcomma(12345)) + """ + ) + result = ipython_session.execute(verify_code) + assert result.success, f"Verification code failed: {result.stderr}" + assert "12,345" in result.stdout diff --git a/tests/integration/dataproc_magics/test_magics_with_papermill.py b/tests/integration/dataproc_magics/test_magics_with_papermill.py new file mode 100644 index 0000000..5f664a7 --- /dev/null +++ b/tests/integration/dataproc_magics/test_magics_with_papermill.py @@ -0,0 +1,251 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+def _run_notebook(
+    pm_env: PapermillEnv,
+    cells: Mapping[str, str],
+    parameters: Mapping[str, str],
+) -> Mapping[str, str]:
+    """Run the given cells in a notebook using papermill.
+
+    Args:
+        pm_env: Papermill environment.
+        cells: Mapping of cell names to contents.
+        parameters: Mapping of papermill parameter names to values.
+
+    Returns:
+        Mapping of cell names to cell stdout.
+    """
+    # Generate and write the notebook to a temporary file.
+    notebook_obj = _generate_test_notebook(cells, parameters.keys())
+    input_nb_path = os.path.join(pm_env.root_dir, "input.ipynb")
+    with open(input_nb_path, "w") as f:
+        nbformat.write(notebook_obj, f)
+
+    # Run the notebook with papermill.
+    output_nb_path = os.path.join(pm_env.root_dir, "output.ipynb")
+    print("Executing notebook with papermill")
+    cmd = [pm_env.exe, input_nb_path, output_nb_path]
+    for key, value in parameters.items():
+        cmd.extend(["-p", key, value])
+
+    result = subprocess.run(cmd, text=True)
+    assert (
+        result.returncode == 0
+    ), f"Papermill execution failed with exit code {result.returncode}."
+
+    # Parse the output notebook and extract stdout from each tagged cell.
+    with open(output_nb_path) as f:
+        nb = nbformat.read(f, as_version=4)
+
+    results = {}
+    for cell in nb.cells:
+        tags = cell.metadata.get("tags", [])
+        if not tags or "parameters" in tags:
+            continue
+        # Assumes one tag per cell, consistent with how the notebook is
+        # generated above.
+        name = tags[0]
+
+        stdouts = [
+            o.text
+            for o in cell.outputs
+            if o.output_type == "stream" and o.name == "stdout"
+        ]
+        results[name] = "".join(stdouts)
+
+    return results
+
+
+@pytest.fixture(scope="function")
+def pm_env() -> Generator[PapermillEnv, None, None]:
+    """Fixture to create a fresh venv with papermill for each test."""
+    temp_dir = tempfile.mkdtemp(prefix="dataproc-magics-pm-")
+    venv_dir = os.path.join(temp_dir, "venv")
+
+    try:
+        print(f"Creating venv to run papermill in {venv_dir}")
+        subprocess.check_call([sys.executable, "-m", "venv", venv_dir])
+        pip_exe = os.path.join(venv_dir, "bin", "pip")
+        print(f"Installing notebook dependencies in {venv_dir}")
+        subprocess.check_call(
+            [
+                pip_exe,
+                "install",
+                "papermill",
+                "google-cloud-storage",
+                "ipykernel",
+            ]
+        )
+        subprocess.check_call([pip_exe, "install", "-e", "."])
+
+        yield PapermillEnv(
+            root_dir=temp_dir,
+            exe=os.path.join(venv_dir, "bin", "papermill"),
+        )
+
+    finally:
+        print(f"Cleaning up temp dir {temp_dir}")
+        shutil.rmtree(temp_dir)
+
+
+@pytest.fixture(scope="module")
+def test_bucket() -> str:
+    """Fixture to get the GCS test bucket name from environment variable."""
+    bucket_name = os.environ.get("DATAPROC_TEST_BUCKET")
+    if not bucket_name:
+        pytest.fail("DATAPROC_TEST_BUCKET environment variable not set")
+    return bucket_name
+
+
+@pytest.fixture(scope="module")
+def gcs_requirements(test_bucket: str) -> str:
+    """Fixture to get the GCS path for a test requirements file."""
+    # TODO: Consider whether we should handle uploading here. Would be annoying
+    # to manage temp buckets, GCing old versions, etc. For a single requirements
+    # file the current approach is simpler.
+    object_name = "test-magics-requirements.txt"
+    url = f"gs://{test_bucket}/{object_name}"
+    print(f"Validating {url} contents")
+
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(test_bucket)
+    blob = bucket.blob(object_name)
+    try:
+        content = blob.download_as_text()
+        assert content == "humanize==4.14.0\n"
+    except Exception as e:
+        pytest.fail(f"Failed to download/verify GCS file. Error: {e}")
+    return url
+
+
+@pytest.fixture(scope="module")
+def gcs_wheel(test_bucket: str) -> str:
+    """Fixture to get GCS path for a test wheel and verify its hash."""
+    pkg_name = "humanize"
+    pkg_version = "4.14.0"
+    file_name = f"{pkg_name}-{pkg_version}-py3-none-any.whl"
+
+    # Get expected hash from PyPI
+    pypi_url = f"https://pypi.org/pypi/{pkg_name}/{pkg_version}/json"
+    with urllib.request.urlopen(pypi_url) as response:
+        pypi_data = json.load(response)
+    wheel_info = next(
+        (url for url in pypi_data["urls"] if url["filename"] == file_name), None
+    )
+    if not wheel_info:
+        pytest.fail(f"Could not find {file_name} in PyPI JSON response.")
+    expected_hash = wheel_info["digests"]["sha256"]
+
+    # Get GCS file and check hash
+    url = f"gs://{test_bucket}/{file_name}"
+    print(f"Validating {url} contents")
+
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(test_bucket)
+    blob = bucket.blob(file_name)
+    try:
+        content = blob.download_as_bytes()
+        actual_hash = hashlib.sha256(content).hexdigest()
+        assert actual_hash == expected_hash
+    except Exception as e:
+        pytest.fail(f"Failed to download/verify GCS file. Error: {e}")
+
+    return url
+
+
+@pytest.mark.parametrize(
+    "pip_line",
+    [
+        pytest.param("%pip install -r {gcs_requirements}", id="r_space"),
+        pytest.param("%pip install -r{gcs_requirements}", id="r_no_space"),
+        pytest.param("%pip install {gcs_wheel}", id="wheel"),
+    ],
+)
+def test_pip_install_from_gcs(
+    pm_env: PapermillEnv,
+    gcs_requirements: str,
+    gcs_wheel: str,
+    pip_line: str,
+):
+    test_cells = {
+        "load_ext": "%load_ext google.cloud.dataproc_magics",
+        "pip_install": pip_line,
+        "code": "import humanize\nprint(humanize.intcomma(12345))",
+    }
+    parameters = {
+        "gcs_requirements": gcs_requirements,
+        "gcs_wheel": gcs_wheel,
+    }
+
+    results = _run_notebook(pm_env, test_cells, parameters)
+
+    install_output = results["pip_install"]
+    assert "Downloading gs://" in install_output
+    assert "Successfully installed humanize-4.14.0" in install_output
+
+    assert results["code"] == "12,345\n"
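
Two mechanisms make the parametrized pip lines above work. Papermill rewrites
the tagged parameters cell (the gcs_requirements = "" and gcs_wheel = ""
assignments emitted by _generate_test_notebook) with the -p values. Then, when
a cell such as

    %pip install -r {gcs_requirements}

runs, IPython expands the {gcs_requirements} expression against the user
namespace (run_line_magic applies var_expand to the argument line unless a
magic opts out via no_var_expand), so the overriding %pip magic receives a
real gs:// URL rather than the literal braces.
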
diff --git a/tests/unit/dataproc_magics/__init__.py b/tests/unit/dataproc_magics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/dataproc_magics/test_downloader.py b/tests/unit/dataproc_magics/test_downloader.py
new file mode 100644
index 0000000..ec64fc8
--- /dev/null
+++ b/tests/unit/dataproc_magics/test_downloader.py
@@ -0,0 +1,73 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+from unittest import mock
+
+from google.cloud.dataproc_magics._internal import dl
+
+
+class TestGcsDownloader(unittest.TestCase):
+
+    def test_download(self):
+        client = mock.MagicMock()
+        mock_blob = mock.MagicMock()
+        mock_blob.name = "my-package-0.1.0.whl"
+
+        with (
+            # Mocking prevents files from being downloaded, but the context
+            # manager still wants to create a new directory under tmpdir.
+            tempfile.TemporaryDirectory() as tmpdir,
+            mock.patch(
+                "google.cloud.dataproc_magics._internal.dl.storage.Blob.from_string",
+                return_value=mock_blob,
+            ) as from_string,
+            dl.GcsDownloader(client, tmpdir) as downloader,
+        ):
+            gcs_url = "gs://my-bucket/my-package-0.1.0.whl"
+            assert downloader._tmpdir is not None
+            expected = os.path.join(downloader._tmpdir, "my-package-0.1.0.whl")
+            actual = downloader.download(gcs_url)
+            from_string.assert_called_once_with(gcs_url, client)
+            mock_blob.download_to_filename.assert_called_once_with(expected)
+            self.assertEqual(actual, expected)
+
+    def test_download_outside_with_block(self):
+        downloader = dl.GcsDownloader(mock.MagicMock(), None)
+        with self.assertRaises(RuntimeError) as raised:
+            downloader.download("gs://my-bucket/my-package-0.1.0.whl")
+        self.assertEqual(
+            "Cannot download outside of a 'with' block",
+            str(raised.exception),
+        )
+
+    def test_cleanup(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with dl.GcsDownloader(mock.MagicMock(), tmpdir) as _:
+                self.assertTrue(os.path.isdir(tmpdir))
+                contents = os.listdir(tmpdir)
+                self.assertEqual(
+                    1,
+                    len(contents),
+                    msg=f"Expected 1 file in {tmpdir}, got {contents}",
+                )
+                inner_tmpdir = os.path.join(tmpdir, contents[0])
+                self.assertTrue(os.path.isdir(inner_tmpdir))
+            self.assertFalse(os.path.exists(inner_tmpdir))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/unit/dataproc_magics/test_magics.py b/tests/unit/dataproc_magics/test_magics.py
new file mode 100644
index 0000000..29f9db5
--- /dev/null
+++ b/tests/unit/dataproc_magics/test_magics.py
@@ -0,0 +1,146 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest import mock
+import shlex
+
+from google.cloud.dataproc_magics._internal import magic
+from google.cloud.dataproc_magics._internal import dl
+
+
+class TestDataprocMagics(unittest.TestCase):
+
+    def setUp(self):
+        self.mock_original_pip = mock.MagicMock()
+        self.mock_gcs_client = mock.MagicMock()
+        self.magics = magic.DataprocMagics(
+            shell=None,
+            original_pip=self.mock_original_pip,
+            gcs_client=self.mock_gcs_client,
+        )
+
+    def _mock_downloader(self, gcs_map):
+        downloader_mock = mock.MagicMock(spec=dl.GcsDownloader)
+        downloader_mock.download.side_effect = lambda url: gcs_map.get(url, url)
+        return downloader_mock
+
+    def test_transform_line_with_gcs_url(self):
+        downloader_mock = self._mock_downloader(
+            {"gs://my-bucket/my-package-0.1.0.whl": "/tmp/my-package-0.1.0.whl"}
+        )
+        line = "install gs://my-bucket/my-package-0.1.0.whl"
+        result = self.magics._transform_line(line, downloader_mock)
+        expected_line = "install /tmp/my-package-0.1.0.whl"
+        self.assertEqual(result, expected_line)
+
+    def test_transform_line_without_gcs_url(self):
+        downloader_mock = self._mock_downloader({})
+        line = "install requests"
+        result = self.magics._transform_line(line, downloader_mock)
+        self.assertEqual(result, line)
+
+    def test_transform_line_with_mixed_args(self):
+        gcs_map = {
+            "gs://my-bucket/pkg1.whl": "/tmp/pkg1.whl",
+            "gs://another-bucket/pkg2.tar.gz": "/tmp/pkg2.tar.gz",
+        }
+        downloader_mock = self._mock_downloader(gcs_map)
+        line = "install gs://my-bucket/pkg1.whl local-pkg.whl gs://another-bucket/pkg2.tar.gz"
+        result = self.magics._transform_line(line, downloader_mock)
+        expected_args = [
+            "install",
+            "/tmp/pkg1.whl",
+            "local-pkg.whl",
+            "/tmp/pkg2.tar.gz",
+        ]
+        expected_line = shlex.join(expected_args)
+        self.assertEqual(result, expected_line)
+
+    def test_transform_line_with_prefixed_gcs_url(self):
+        gcs_map = {"gs://my-bucket/reqs.txt": "/tmp/reqs.txt"}
+        downloader_mock = self._mock_downloader(gcs_map)
+        line = "install -rgs://my-bucket/reqs.txt"
+        result = self.magics._transform_line(line, downloader_mock)
+        expected_line = "install -r/tmp/reqs.txt"
+        self.assertEqual(result, expected_line)
+
+    def test_transform_line_with_equals_prefixed_gcs_url(self):
+        gcs_map = {"gs://my-bucket/reqs.txt": "/tmp/reqs.txt"}
+        downloader_mock = self._mock_downloader(gcs_map)
+        line = "install --requirement=gs://my-bucket/reqs.txt"
+        result = self.magics._transform_line(line, downloader_mock)
+        expected_line = "install --requirement=/tmp/reqs.txt"
+        self.assertEqual(result, expected_line)
+
+    def test_transform_line_with_multiple_prefixed_gcs_urls(self):
+        gcs_map = {
+            "gs://my-bucket/reqs.txt": "/tmp/reqs.txt",
+            "gs://another-bucket/constraint.txt": "/tmp/constraint.txt",
+        }
+        downloader_mock = self._mock_downloader(gcs_map)
+        args = [
+            "install",
+            "--requirement=gs://my-bucket/reqs.txt",
+            "--constraint=gs://another-bucket/constraint.txt",
+        ]
+        result = self.magics._transform_line(" ".join(args), downloader_mock)
+        expected_args = [
+            "install",
+            "--requirement=/tmp/reqs.txt",
+            "--constraint=/tmp/constraint.txt",
+        ]
+        expected_line = shlex.join(expected_args)
+        self.assertEqual(result, expected_line)
+
+    def test_transform_line_with_gcs_url_and_other_args(self):
+        gcs_map = {"gs://my-bucket/reqs.txt": "/tmp/reqs.txt"}
+        downloader_mock = self._mock_downloader(gcs_map)
+        line = "install --verbose -rgs://my-bucket/reqs.txt other-pkg"
+        result = self.magics._transform_line(line, downloader_mock)
+        expected_args = [
+            "install",
+            "--verbose",
"-r/tmp/reqs.txt", + "other-pkg", + ] + expected_line = shlex.join(expected_args) + self.assertEqual(result, expected_line) + + def test_transform_line_with_non_option_gs_not_at_start(self): + downloader_mock = self._mock_downloader({}) + line = "install bugs://my-bucket/foo" + result = self.magics._transform_line(line, downloader_mock) + self.assertEqual(result, line) + + def test_transform_line_with_gs_as_substring_of_url_scheme(self): + gcs_map = {"gs://my-bucket/foo": "/tmp/foo"} + downloader_mock = self._mock_downloader(gcs_map) + line = "install -rbugs://my-bucket/foo" + result = self.magics._transform_line(line, downloader_mock) + # This is arguably wrong: correct parsing is "-r bugs://...", which is + # some other custom URL which we should not be attempting to fetch. In + # practice, it would almost certainly fail at runtime regardless of + # whether we replace the argument: + # * If we didn't, pip would fail to fetch bugs://... + # * When we do, pip would almost certainly fail to read the file + # bu/tmp/reqs.txt + # Getting this 100% correct requires much more sophisticated command + # line parsing logic; this test case mostly just documents existing + # behavior. + self.assertEqual(result, "install -rbu/tmp/foo") + + +if __name__ == "__main__": + unittest.main()