Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions google/cloud/dataproc_magics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.0"


from google.cloud import storage
from ._internal import magic


_original_pip = None


def load_ipython_extension(ipython):
"""Called by IPython when this module is loaded as an IPython ext."""
global _original_pip
_original_pip = ipython.find_magic("pip")

if _original_pip:
magics = magic.DataprocMagics(
shell=ipython,
original_pip=_original_pip,
gcs_client=storage.Client(),
)
ipython.register_magics(magics)


def unload_ipython_extension(ipython):
"""Called by IPython when this module is unloaded as an IPython ext."""
global _original_pip
if _original_pip:
ipython.register_magic_function(
_original_pip, magic_kind="line", magic_name="pip"
)
_original_pip = None


__all__ = [
"__version__",
"load_ipython_extension",
"unload_ipython_extension",
]
Empty file.
66 changes: 66 additions & 0 deletions google/cloud/dataproc_magics/_internal/dl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for downloading files from GCS."""

import os
import shutil
import tempfile

from google.cloud import storage


class GcsDownloader:
"""Helper for downloading files from GCS.

An instance is a single-use context manager that downloads all its temporary
files to a per-instance temporary directory under its config's tmpdir.
"""

def __init__(self, client: storage.Client, tmpdir: str | None):
self._client = client
self._base_tmpdir = tmpdir
# Per-context tmpdir inside base.
self._tmpdir: str | None = None

def __enter__(self):
if self._tmpdir is not None:
raise RuntimeError(f"{type(self)} has already been entered")
self._tmpdir = tempfile.mkdtemp(dir=self._base_tmpdir)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self._tmpdir is None:
raise RuntimeError(f"{type(self)} has not been entered")
print(f"Removing GCS temporary download directory {self._tmpdir}")
try:
shutil.rmtree(self._tmpdir)
except OSError as e:
print(
f"Warning: Failed to remove temporary directory {self._tmpdir}: {e}"
)
self._tmpdir = None

def download(self, url: str):
"""Download the given GCS URL to a temporary file."""
if self._tmpdir is None:
raise RuntimeError("Cannot download outside of a 'with' block")
blob = storage.Blob.from_string(url, self._client)
if blob.name is None:
raise ValueError(f"Couldn't parse blob from URL: {url}")
blob_name = blob.name.rsplit("/", 1)[-1]
tmpfile = os.path.join(self._tmpdir, blob_name)
print(f"Downloading {url} to {tmpfile}")
blob.download_to_filename(tmpfile)
return tmpfile
69 changes: 69 additions & 0 deletions google/cloud/dataproc_magics/_internal/magic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataproc magic implementations."""

from collections.abc import Callable
import shlex

from google.cloud import storage
from IPython.core import magic
import traitlets

from . import dl


@magic.magics_class
class DataprocMagics(magic.Magics):
"""Dataproc magics class."""

tmpdir = traitlets.Unicode(
default_value=None,
allow_none=True,
help="Temporary directory for downloads; defaults to system temp dir",
).tag(config=True)

def __init__(
self,
shell,
original_pip: Callable[[str], None],
gcs_client: storage.Client,
**kwargs,
):
super().__init__(shell, **kwargs)
self._original_pip = original_pip
self._gcs_client = gcs_client

def _transform_line(self, line: str, downloader: dl.GcsDownloader) -> str:
new_args = []
for arg in shlex.split(line):
gcs_url_start = arg.find("gs://")
# gs:// found either at the beginning of an arg, or anywhere in an
# option/value starting with - (short or long form).
if gcs_url_start != -1 and (arg[0] == "-" or gcs_url_start == 0):
prefix = arg[:gcs_url_start]
url = arg[gcs_url_start:]
new_args.append(prefix + downloader.download(url))
else:
new_args.append(arg)
return shlex.join(new_args)

@magic.line_magic
def pip(self, line: str) -> None:
if "gs://" in line:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check if "install" is also present in case there are some other custom flags that may accept a GCS url?

with dl.GcsDownloader(self._gcs_client, self.tmpdir) as downloader:
new_line = self._transform_line(line, downloader)
self._original_pip(new_line)
else:
self._original_pip(line)
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
google-api-core>=2.19
google-cloud-dataproc>=5.18
google-cloud-storage>=3.7.0
ipython~=9.1
ipywidgets>=8.0.0
packaging>=20.0
Expand All @@ -9,3 +10,4 @@ setuptools>=72.0
sparksql-magic>=0.0.3
tqdm>=4.67
websockets>=14.0
jupyter-kernel-test
2 changes: 2 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pytest>=8.0
pytest-xdist>=3.0
jupyter-client>=8.0
nbformat>=5.10
Empty file.
152 changes: 152 additions & 0 deletions tests/integration/dataproc_magics/test_magics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import sys
import textwrap
import subprocess
import tempfile
import shutil
import unittest

from jupyter_kernel_test import KernelTests
from jupyter_client.kernelspec import KernelSpecManager
from jupyter_client.manager import KernelManager

from google.cloud import storage


class TestDataprocMagics(KernelTests):
kernel_name = "python3" # Will be updated in setUp

@classmethod
def setUpClass(cls):
# Override to prevent default kernel from starting.
# We start a new kernel for each test method.
pass

@classmethod
def tearDownClass(cls):
# Override to prevent default kernel from being shut down.
pass

def _get_requirements_file(self):
bucket_name = os.environ.get("DATAPROC_TEST_BUCKET")
if not bucket_name:
self.skipTest("DATAPROC_TEST_BUCKET environment variable not set")

object_name = "test-magics-requirements.txt"
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(object_name)

# Download and verify content
downloaded_content = blob.download_as_text()
self.assertEqual(downloaded_content, "humanize==4.14.0\n")

return bucket_name, object_name

def setUp(self):
self.temp_dir = tempfile.mkdtemp(prefix="dataproc-magics-test-")
venv_dir = os.path.join(self.temp_dir, "venv")

# Create venv
subprocess.run(
[sys.executable, "-m", "venv", venv_dir],
check=True,
capture_output=True,
)

# Install deps
pip_exe = os.path.join(venv_dir, "bin", "pip")
subprocess.run(
[pip_exe, "install", "ipykernel", "google-cloud-storage"],
check=True,
capture_output=True,
)
subprocess.run(
[pip_exe, "install", "-e", "."], check=True, capture_output=True
)

# Install kernelspec
python_exe = os.path.join(venv_dir, "bin", "python")
self.kernel_name = f"temp-kernel-{os.path.basename(self.temp_dir)}"

subprocess.run(
[
python_exe,
"-m",
"ipykernel",
"install",
"--name",
self.kernel_name,
"--prefix",
self.temp_dir,
],
check=True,
capture_output=True,
)

kernel_dir = os.path.join(self.temp_dir, "share", "jupyter", "kernels")

# Start kernel
ksm = KernelSpecManager(kernel_dirs=[kernel_dir])
self.km = KernelManager(
kernel_spec_manager=ksm, kernel_name=self.kernel_name
)
self.km.start_kernel()

self.kc = self.km.client()
self.kc.load_connection_file()
self.kc.start_channels()
self.kc.wait_for_ready()

def tearDown(self):
self.kc.stop_channels()
self.km.shutdown_kernel()
shutil.rmtree(self.temp_dir)

def test_pip_install_from_gcs(self):
bucket_name, object_name = self._get_requirements_file()

# Load extension
reply, output_msgs = self.execute_helper(
"%load_ext google.cloud.dataproc_magics"
)
# Assert that there are no stream messages (stdout/stderr)
self.assertFalse(
any(msg["msg_type"] == "stream" for msg in output_msgs)
)

# Pip install
install_cmd = f"%pip install -r gs://{bucket_name}/{object_name}"
self.assert_in_stdout(
install_cmd, "Successfully installed humanize-4.14.0"
)

# Import and use humanize
code = textwrap.dedent(
"""
import humanize
print(humanize.intcomma(12345))
"""
)
# assert_stdout adds a newline to the expected output if it's not present,
# because print statements typically add a newline.
self.assert_stdout(code, "12,345\n")


if __name__ == "__main__":
unittest.main()
Loading
Loading