Skip to content

Commit d52e142

Browse files
committed
feat: Add new dataproc_magics module providing %pip install features
The initial goal of this module is to shadow the built-in `%pip` magic so that it can install from `gs://` URLs in addition to local files and `https://` URLs. (Other magic implementations may be added later.)

In this commit, only the following constructions are supported:

- `%pip install gs://bucket/mypackage.whl`
- `%pip install -r gs://bucket/myrequirements.txt`

Recursively handling `gs://` URLs inside requirements files is not yet supported.
1 parent 9228492 commit d52e142

File tree

13 files changed

+1013
-0
lines changed

13 files changed

+1013
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
__version__ = "0.1.0"
16+
17+
18+
from google.cloud import storage
19+
from ._internal import magic
20+
21+
22+
_original_pip = None
23+
24+
25+
def load_ipython_extension(ipython):
    """Register the Dataproc magics when IPython loads this extension.

    The ``pip`` magic currently known to ``ipython`` is looked up and stored
    in the module-level ``_original_pip`` so that it can be re-registered on
    unload. When no such magic is found, nothing is registered.
    """
    global _original_pip
    _original_pip = ipython.find_magic("pip")
    if not _original_pip:
        return

    # The wrapper delegates to the saved magic, so it is handed over here.
    dataproc_magics = magic.DataprocMagics(
        shell=ipython,
        original_pip=_original_pip,
        gcs_client=storage.Client(),
    )
    ipython.register_magics(dataproc_magics)
37+
38+
39+
def unload_ipython_extension(ipython):
    """Re-register the saved ``pip`` line magic when the extension unloads.

    Does nothing to ``ipython`` when no original magic was saved. The saved
    reference is cleared afterwards.
    """
    global _original_pip
    saved_pip = _original_pip
    if saved_pip:
        ipython.register_magic_function(
            saved_pip, magic_kind="line", magic_name="pip"
        )
    _original_pip = None
47+
48+
49+
__all__ = [
50+
"__version__",
51+
"load_ipython_extension",
52+
"unload_ipython_extension",
53+
]

google/cloud/dataproc_magics/_internal/__init__.py

Whitespace-only changes.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for downloading files from GCS."""
16+
17+
import os
18+
import shutil
19+
import tempfile
20+
21+
from google.cloud import storage
22+
23+
24+
class GcsDownloader:
25+
"""Helper for downloading files from GCS.
26+
27+
An instance is a single-use context manager that downloads all its temporary
28+
files to a per-instance temporary directory under its config's tmpdir.
29+
"""
30+
31+
def __init__(self, client: storage.Client, tmpdir: str | None):
32+
self._client = client
33+
self._base_tmpdir = tmpdir
34+
# Per-context tmpdir inside base.
35+
self._tmpdir: str | None = None
36+
37+
def __enter__(self):
38+
if self._tmpdir is not None:
39+
raise RuntimeError(f"{type(self)} has already been entered")
40+
self._tmpdir = tempfile.mkdtemp(dir=self._base_tmpdir)
41+
return self
42+
43+
def __exit__(self, exc_type, exc_val, exc_tb):
44+
if self._tmpdir is None:
45+
raise RuntimeError(f"{type(self)} has not been entered")
46+
print(f"Removing GCS temporary download directory {self._tmpdir}")
47+
try:
48+
shutil.rmtree(self._tmpdir)
49+
except OSError as e:
50+
print(
51+
f"Warning: Failed to remove temporary directory {self._tmpdir}: {e}"
52+
)
53+
self._tmpdir = None
54+
55+
def download(self, url: str):
56+
"""Download the given GCS URL to a temporary file."""
57+
if self._tmpdir is None:
58+
raise RuntimeError("Cannot download outside of a 'with' block")
59+
blob = storage.Blob.from_string(url, self._client)
60+
if blob.name is None:
61+
raise ValueError(f"Couldn't parse blob from URL: {url}")
62+
blob_name = blob.name.rsplit("/", 1)[-1]
63+
tmpfile = os.path.join(self._tmpdir, blob_name)
64+
print(f"Downloading {url} to {tmpfile}")
65+
blob.download_to_filename(tmpfile)
66+
return tmpfile
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Dataproc magic implementations."""
16+
17+
from collections.abc import Callable
18+
import shlex
19+
20+
from google.cloud import storage
21+
from IPython.core import magic
22+
import traitlets
23+
24+
from . import dl
25+
26+
27+
@magic.magics_class
class DataprocMagics(magic.Magics):
    """Magics class providing a ``%pip`` that understands ``gs://`` URLs."""

    # Configurable base directory handed to the GCS downloader.
    tmpdir = traitlets.Unicode(
        default_value=None,
        allow_none=True,
        help="Temporary directory for downloads; defaults to system temp dir",
    ).tag(config=True)

    def __init__(
        self,
        shell,
        original_pip: Callable[[str], None],
        gcs_client: storage.Client,
        **kwargs,
    ):
        """Store the wrapped ``pip`` magic and the GCS client.

        Args:
            shell: Passed through to ``magic.Magics``.
            original_pip: Callable invoked with the (possibly rewritten) line.
            gcs_client: Client used to fetch ``gs://`` objects.
            **kwargs: Passed through to ``magic.Magics``.
        """
        super().__init__(shell, **kwargs)
        self._gcs_client = gcs_client
        self._original_pip = original_pip

    def _transform_line(self, line: str, downloader: dl.GcsDownloader) -> str:
        """Return ``line`` with each ``gs://`` URL replaced by a local path.

        A URL is rewritten when the argument starts with it, or when the
        argument is an option (starts with ``-``) that contains it; in the
        latter case the text before the URL is kept as a prefix.
        """
        rewritten = []
        for token in shlex.split(line):
            prefix, scheme, remainder = token.partition("gs://")
            # `scheme` is empty when the token holds no gs:// URL at all.
            if scheme and (not prefix or token.startswith("-")):
                token = prefix + downloader.download(scheme + remainder)
            rewritten.append(token)
        return shlex.join(rewritten)

    @magic.line_magic
    def pip(self, line: str) -> None:
        """Run the original ``pip`` magic, downloading ``gs://`` URLs first."""
        if "gs://" not in line:
            self._original_pip(line)
            return

        # Downloads must stay on disk until the wrapped pip call has finished.
        with dl.GcsDownloader(self._gcs_client, self.tmpdir) as downloader:
            self._original_pip(self._transform_line(line, downloader))

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
google-api-core>=2.19
22
google-cloud-dataproc>=5.18
3+
google-cloud-storage>=3.7.0
34
ipython~=9.1
45
ipywidgets>=8.0.0
56
packaging>=20.0
@@ -9,3 +10,4 @@ setuptools>=72.0
910
sparksql-magic>=0.0.3
1011
tqdm>=4.67
1112
websockets>=14.0
13+
jupyter-kernel-test

requirements-test.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
pytest>=8.0
22
pytest-xdist>=3.0
3+
jupyter-client>=8.0
4+
nbformat>=5.10

tests/integration/dataproc_magics/__init__.py

Whitespace-only changes.
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import os
17+
import sys
18+
import textwrap
19+
import subprocess
20+
import tempfile
21+
import shutil
22+
import unittest
23+
24+
from jupyter_kernel_test import KernelTests
25+
from jupyter_client.kernelspec import KernelSpecManager
26+
from jupyter_client.manager import KernelManager
27+
28+
from google.cloud import storage
29+
30+
31+
class TestDataprocMagics(KernelTests):
    """Integration test running the Dataproc ``%pip`` magic in a real kernel.

    Each test method gets a fresh virtualenv, kernelspec and kernel, all
    living under a per-test temporary directory.
    """

    kernel_name = "python3"  # Will be updated in setUp

    @classmethod
    def setUpClass(cls):
        # Override to prevent default kernel from starting.
        # We start a new kernel for each test method.
        pass

    @classmethod
    def tearDownClass(cls):
        # Override to prevent default kernel from being shut down.
        pass

    def _get_requirements_file(self):
        """Return ``(bucket_name, object_name)`` of the test requirements file.

        Skips the test when ``DATAPROC_TEST_BUCKET`` is not set, and fails it
        when the object's content is not the expected single requirement.
        """
        bucket_name = os.environ.get("DATAPROC_TEST_BUCKET")
        if not bucket_name:
            self.skipTest("DATAPROC_TEST_BUCKET environment variable not set")

        object_name = "test-magics-requirements.txt"
        storage_client = storage.Client()
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(object_name)

        # Download and verify content
        downloaded_content = blob.download_as_text()
        self.assertEqual(downloaded_content, "humanize==4.14.0\n")

        return bucket_name, object_name

    def setUp(self):
        """Build a venv with this package installed and start a kernel in it.

        Every resource is registered with ``addCleanup`` as soon as it
        exists: cleanups still run when a later step of ``setUp`` raises,
        whereas ``tearDown`` would not.
        """
        self.temp_dir = tempfile.mkdtemp(prefix="dataproc-magics-test-")
        self.addCleanup(shutil.rmtree, self.temp_dir)
        venv_dir = os.path.join(self.temp_dir, "venv")

        # Create venv
        subprocess.run(
            [sys.executable, "-m", "venv", venv_dir],
            check=True,
            capture_output=True,
        )

        # Install deps
        pip_exe = os.path.join(venv_dir, "bin", "pip")
        subprocess.run(
            [pip_exe, "install", "ipykernel", "google-cloud-storage"],
            check=True,
            capture_output=True,
        )
        subprocess.run(
            [pip_exe, "install", "-e", "."], check=True, capture_output=True
        )

        # Install kernelspec
        python_exe = os.path.join(venv_dir, "bin", "python")
        self.kernel_name = f"temp-kernel-{os.path.basename(self.temp_dir)}"

        subprocess.run(
            [
                python_exe,
                "-m",
                "ipykernel",
                "install",
                "--name",
                self.kernel_name,
                "--prefix",
                self.temp_dir,
            ],
            check=True,
            capture_output=True,
        )

        kernel_dir = os.path.join(self.temp_dir, "share", "jupyter", "kernels")

        # Start kernel
        ksm = KernelSpecManager(kernel_dirs=[kernel_dir])
        self.km = KernelManager(
            kernel_spec_manager=ksm, kernel_name=self.kernel_name
        )
        self.km.start_kernel()
        # Cleanups run last-in-first-out: channels stop, then the kernel
        # shuts down, then the temporary directory is removed.
        self.addCleanup(self.km.shutdown_kernel)

        self.kc = self.km.client()
        self.kc.load_connection_file()
        self.kc.start_channels()
        self.addCleanup(self.kc.stop_channels)
        self.kc.wait_for_ready()

    def tearDown(self):
        # All per-test resources are released by the cleanups registered in
        # setUp; this override is kept so no inherited tearDown runs instead.
        pass

    def test_pip_install_from_gcs(self):
        """Install a requirements file from GCS and import the package."""
        bucket_name, object_name = self._get_requirements_file()

        # Load extension
        reply, output_msgs = self.execute_helper(
            "%load_ext google.cloud.dataproc_magics"
        )
        # Assert that there are no stream messages (stdout/stderr)
        self.assertFalse(
            any(msg["msg_type"] == "stream" for msg in output_msgs)
        )

        # Pip install
        install_cmd = f"%pip install -r gs://{bucket_name}/{object_name}"
        self.assert_in_stdout(
            install_cmd, "Successfully installed humanize-4.14.0"
        )

        # Import and use humanize
        code = textwrap.dedent(
            """
            import humanize
            print(humanize.intcomma(12345))
            """
        )
        # assert_stdout adds a newline to the expected output if it's not present,
        # because print statements typically add a newline.
        self.assert_stdout(code, "12,345\n")
149+
150+
151+
if __name__ == "__main__":
152+
unittest.main()

0 commit comments

Comments
 (0)