-
Notifications
You must be signed in to change notification settings - Fork 12
feat: Add new dataproc_magics module providing %pip install features
#175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dborowitz
wants to merge
1
commit into
GoogleCloudDataproc:main
Choose a base branch
from
dborowitz:magics
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| __version__ = "0.1.0" | ||
|
|
||
|
|
||
| from google.cloud import storage | ||
| from ._internal import magic | ||
|
|
||
|
|
||
| _original_pip = None | ||
|
|
||
|
|
||
| def load_ipython_extension(ipython): | ||
| """Called by IPython when this module is loaded as an IPython ext.""" | ||
| global _original_pip | ||
| _original_pip = ipython.find_magic("pip") | ||
|
|
||
| if _original_pip: | ||
| magics = magic.DataprocMagics( | ||
| shell=ipython, | ||
| original_pip=_original_pip, | ||
| gcs_client=storage.Client(), | ||
| ) | ||
| ipython.register_magics(magics) | ||
|
|
||
|
|
||
| def unload_ipython_extension(ipython): | ||
| """Called by IPython when this module is unloaded as an IPython ext.""" | ||
| global _original_pip | ||
| if _original_pip: | ||
| ipython.register_magic_function( | ||
| _original_pip, magic_kind="line", magic_name="pip" | ||
| ) | ||
| _original_pip = None | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "__version__", | ||
| "load_ipython_extension", | ||
| "unload_ipython_extension", | ||
| ] |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Utilities for downloading files from GCS.""" | ||
|
|
||
| import os | ||
| import shutil | ||
| import tempfile | ||
|
|
||
| from google.cloud import storage | ||
|
|
||
|
|
||
| class GcsDownloader: | ||
| """Helper for downloading files from GCS. | ||
|
|
||
| An instance is a single-use context manager that downloads all its temporary | ||
| files to a per-instance temporary directory under its config's tmpdir. | ||
| """ | ||
|
|
||
| def __init__(self, client: storage.Client, tmpdir: str | None): | ||
| self._client = client | ||
| self._base_tmpdir = tmpdir | ||
| # Per-context tmpdir inside base. | ||
| self._tmpdir: str | None = None | ||
|
|
||
| def __enter__(self): | ||
| if self._tmpdir is not None: | ||
| raise RuntimeError(f"{type(self)} has already been entered") | ||
| self._tmpdir = tempfile.mkdtemp(dir=self._base_tmpdir) | ||
| return self | ||
|
|
||
| def __exit__(self, exc_type, exc_val, exc_tb): | ||
| if self._tmpdir is None: | ||
| raise RuntimeError(f"{type(self)} has not been entered") | ||
| print(f"Removing GCS temporary download directory {self._tmpdir}") | ||
| try: | ||
| shutil.rmtree(self._tmpdir) | ||
| except OSError as e: | ||
| print( | ||
| f"Warning: Failed to remove temporary directory {self._tmpdir}: {e}" | ||
| ) | ||
| self._tmpdir = None | ||
|
|
||
| def download(self, url: str): | ||
| """Download the given GCS URL to a temporary file.""" | ||
| if self._tmpdir is None: | ||
| raise RuntimeError("Cannot download outside of a 'with' block") | ||
| blob = storage.Blob.from_string(url, self._client) | ||
| if blob.name is None: | ||
| raise ValueError(f"Couldn't parse blob from URL: {url}") | ||
| blob_name = blob.name.rsplit("/", 1)[-1] | ||
| tmpfile = os.path.join(self._tmpdir, blob_name) | ||
| print(f"Downloading {url} to {tmpfile}") | ||
| blob.download_to_filename(tmpfile) | ||
| return tmpfile |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Dataproc magic implementations.""" | ||
|
|
||
| from collections.abc import Callable | ||
| import shlex | ||
|
|
||
| from google.cloud import storage | ||
| from IPython.core import magic | ||
| import traitlets | ||
|
|
||
| from . import dl | ||
|
|
||
|
|
||
| @magic.magics_class | ||
| class DataprocMagics(magic.Magics): | ||
| """Dataproc magics class.""" | ||
|
|
||
| tmpdir = traitlets.Unicode( | ||
| default_value=None, | ||
| allow_none=True, | ||
| help="Temporary directory for downloads; defaults to system temp dir", | ||
| ).tag(config=True) | ||
|
|
||
| def __init__( | ||
| self, | ||
| shell, | ||
| original_pip: Callable[[str], None], | ||
| gcs_client: storage.Client, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(shell, **kwargs) | ||
| self._original_pip = original_pip | ||
| self._gcs_client = gcs_client | ||
|
|
||
| def _transform_line(self, line: str, downloader: dl.GcsDownloader) -> str: | ||
| new_args = [] | ||
| for arg in shlex.split(line): | ||
| gcs_url_start = arg.find("gs://") | ||
| # gs:// found either at the beginning of an arg, or anywhere in an | ||
| # option/value starting with - (short or long form). | ||
| if gcs_url_start != -1 and (arg[0] == "-" or gcs_url_start == 0): | ||
| prefix = arg[:gcs_url_start] | ||
| url = arg[gcs_url_start:] | ||
| new_args.append(prefix + downloader.download(url)) | ||
| else: | ||
| new_args.append(arg) | ||
| return shlex.join(new_args) | ||
|
|
||
| @magic.line_magic | ||
| def pip(self, line: str) -> None: | ||
| if "gs://" in line: | ||
| with dl.GcsDownloader(self._gcs_client, self.tmpdir) as downloader: | ||
| new_line = self._transform_line(line, downloader) | ||
| self._original_pip(new_line) | ||
| else: | ||
| self._original_pip(line) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,4 @@ | ||
| pytest>=8.0 | ||
| pytest-xdist>=3.0 | ||
| jupyter-client>=8.0 | ||
| nbformat>=5.10 |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| # Copyright 2024 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| import os | ||
| import sys | ||
| import textwrap | ||
| import subprocess | ||
| import tempfile | ||
| import shutil | ||
| import unittest | ||
|
|
||
| from jupyter_kernel_test import KernelTests | ||
| from jupyter_client.kernelspec import KernelSpecManager | ||
| from jupyter_client.manager import KernelManager | ||
|
|
||
| from google.cloud import storage | ||
|
|
||
|
|
||
| class TestDataprocMagics(KernelTests): | ||
| kernel_name = "python3" # Will be updated in setUp | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| # Override to prevent default kernel from starting. | ||
| # We start a new kernel for each test method. | ||
| pass | ||
|
|
||
| @classmethod | ||
| def tearDownClass(cls): | ||
| # Override to prevent default kernel from being shut down. | ||
| pass | ||
|
|
||
| def _get_requirements_file(self): | ||
| bucket_name = os.environ.get("DATAPROC_TEST_BUCKET") | ||
| if not bucket_name: | ||
| self.skipTest("DATAPROC_TEST_BUCKET environment variable not set") | ||
|
|
||
| object_name = "test-magics-requirements.txt" | ||
| storage_client = storage.Client() | ||
| bucket = storage_client.bucket(bucket_name) | ||
| blob = bucket.blob(object_name) | ||
|
|
||
| # Download and verify content | ||
| downloaded_content = blob.download_as_text() | ||
| self.assertEqual(downloaded_content, "humanize==4.14.0\n") | ||
|
|
||
| return bucket_name, object_name | ||
|
|
||
| def setUp(self): | ||
| self.temp_dir = tempfile.mkdtemp(prefix="dataproc-magics-test-") | ||
| venv_dir = os.path.join(self.temp_dir, "venv") | ||
|
|
||
| # Create venv | ||
| subprocess.run( | ||
| [sys.executable, "-m", "venv", venv_dir], | ||
| check=True, | ||
| capture_output=True, | ||
| ) | ||
|
|
||
| # Install deps | ||
| pip_exe = os.path.join(venv_dir, "bin", "pip") | ||
| subprocess.run( | ||
| [pip_exe, "install", "ipykernel", "google-cloud-storage"], | ||
| check=True, | ||
| capture_output=True, | ||
| ) | ||
| subprocess.run( | ||
| [pip_exe, "install", "-e", "."], check=True, capture_output=True | ||
| ) | ||
|
|
||
| # Install kernelspec | ||
| python_exe = os.path.join(venv_dir, "bin", "python") | ||
| self.kernel_name = f"temp-kernel-{os.path.basename(self.temp_dir)}" | ||
|
|
||
| subprocess.run( | ||
| [ | ||
| python_exe, | ||
| "-m", | ||
| "ipykernel", | ||
| "install", | ||
| "--name", | ||
| self.kernel_name, | ||
| "--prefix", | ||
| self.temp_dir, | ||
| ], | ||
| check=True, | ||
| capture_output=True, | ||
| ) | ||
|
|
||
| kernel_dir = os.path.join(self.temp_dir, "share", "jupyter", "kernels") | ||
|
|
||
| # Start kernel | ||
| ksm = KernelSpecManager(kernel_dirs=[kernel_dir]) | ||
| self.km = KernelManager( | ||
| kernel_spec_manager=ksm, kernel_name=self.kernel_name | ||
| ) | ||
| self.km.start_kernel() | ||
|
|
||
| self.kc = self.km.client() | ||
| self.kc.load_connection_file() | ||
| self.kc.start_channels() | ||
| self.kc.wait_for_ready() | ||
|
|
||
| def tearDown(self): | ||
| self.kc.stop_channels() | ||
| self.km.shutdown_kernel() | ||
| shutil.rmtree(self.temp_dir) | ||
|
|
||
| def test_pip_install_from_gcs(self): | ||
| bucket_name, object_name = self._get_requirements_file() | ||
|
|
||
| # Load extension | ||
| reply, output_msgs = self.execute_helper( | ||
| "%load_ext google.cloud.dataproc_magics" | ||
| ) | ||
| # Assert that there are no stream messages (stdout/stderr) | ||
| self.assertFalse( | ||
| any(msg["msg_type"] == "stream" for msg in output_msgs) | ||
| ) | ||
|
|
||
| # Pip install | ||
| install_cmd = f"%pip install -r gs://{bucket_name}/{object_name}" | ||
| self.assert_in_stdout( | ||
| install_cmd, "Successfully installed humanize-4.14.0" | ||
| ) | ||
|
|
||
| # Import and use humanize | ||
| code = textwrap.dedent( | ||
| """ | ||
| import humanize | ||
| print(humanize.intcomma(12345)) | ||
| """ | ||
| ) | ||
| # assert_stdout adds a newline to the expected output if it's not present, | ||
| # because print statements typically add a newline. | ||
| self.assert_stdout(code, "12,345\n") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to check if "install" is also present in case there are some other custom flags that may accept a GCS url?