Commit 274f013

feat: add support for sparksql magic (#137)
1 parent 30e22bb commit 274f013

File tree

5 files changed, +178 -2 lines changed

DEVELOPING.md
README.md
google/cloud/dataproc_spark_connect/session.py
requirements-dev.txt
tests/integration/test_session.py

DEVELOPING.md

Lines changed: 22 additions & 0 deletions
@@ -39,6 +39,28 @@ env \
 pytest --tb=auto -v
 ```
 
+## Testing with Magic Support
+
+To run tests with magic functionality, install the required dependencies manually:
+
+```sh
+pip install .
+pip install IPython sparksql-magic
+```
+
+Then run tests as normal. Any magic-related tests will automatically detect and use the available dependencies.
+
+## Testing without Magic Support
+
+To run tests without the magic dependencies, simply install the base package:
+
+```sh
+pip install .
+pytest
+```
+
+Tests that require magic functionality will be automatically skipped if the dependencies are not available.
+
 The integration tests in particular can take a while to run. To speed up the
 testing cycle, you can run them in parallel. You can do so using the `xdist`
 plugin by setting the `-n` flag to the number of parallel runners you want to
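
The automatic skipping described in the DEVELOPING.md addition relies on pytest's `importorskip` helper, the same call the new integration tests in this commit use. A minimal sketch of the pattern; the test name and body below are illustrative, not part of the commit:

```python
# Illustrative sketch of the skip pattern used by the magic-related tests.
# The test name and body are hypothetical; see tests/integration/test_session.py below.
import pytest


def test_something_magic_related():
    # Skip (rather than fail) this test when the optional packages are missing.
    pytest.importorskip("IPython")
    pytest.importorskip("sparksql_magic")

    # ...magic-dependent assertions would go here...
```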

README.md

Lines changed: 47 additions & 0 deletions
@@ -54,6 +54,53 @@ environment variables:
 spark = DataprocSparkSession.builder.dataprocSessionConfig(session_config).getOrCreate()
 ```
 
+### Using Spark SQL Magic Commands (Jupyter Notebooks)
+
+The package supports the [sparksql-magic](https://github.com/cryeo/sparksql-magic) library for executing Spark SQL queries directly in Jupyter notebooks.
+
+**Installation**: To use magic commands, install the required dependencies manually:
+```bash
+pip install dataproc-spark-connect
+pip install IPython sparksql-magic
+```
+
+1. Load the magic extension:
+```python
+%load_ext sparksql_magic
+```
+
+2. Configure default settings (optional):
+```python
+%config SparkSql.limit=20
+```
+
+3. Execute SQL queries:
+```python
+%%sparksql
+SELECT * FROM your_table
+```
+
+4. Advanced usage with options:
+```python
+# Cache results and create a view
+%%sparksql --cache --view result_view df
+SELECT * FROM your_table WHERE condition = true
+```
+
+Available options:
+- `--cache` / `-c`: Cache the DataFrame
+- `--eager` / `-e`: Cache with eager loading
+- `--view VIEW` / `-v VIEW`: Create a temporary view
+- `--limit N` / `-l N`: Override the default row display limit
+- `variable_name`: Store the result in a variable
+
+See [sparksql-magic](https://github.com/cryeo/sparksql-magic) for more examples.
+
+**Note**: Magic commands are optional. If you only need basic DataprocSparkSession functionality without Jupyter magic support, install only the base package:
+```bash
+pip install dataproc-spark-connect
+```
+
 ## Developing
 
 For development instructions see [guide](DEVELOPING.md).
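
Taken together, the README steps amount to the following notebook flow. This is a hedged sketch, not part of the diff: the import path is assumed from the package layout in this commit (google/cloud/dataproc_spark_connect), and the query is a stand-in for `your_table`:

```python
# Sketch of the end-to-end notebook flow described above (not part of this diff).
# Import path assumed from the package layout in this commit.
from google.cloud.dataproc_spark_connect import DataprocSparkSession

# getOrCreate() now also registers the session as the active PySpark
# SparkSession (see the session.py change below), which is what the magic picks up.
spark = DataprocSparkSession.builder.getOrCreate()

# Subsequent notebook cells:
#
#   %load_ext sparksql_magic
#
#   %%sparksql --limit 10 result_df
#   SELECT 1 AS id, 'Dataproc' AS product
#
# result_df then holds the query result as a DataFrame in the notebook namespace.
```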

google/cloud/dataproc_spark_connect/session.py

Lines changed: 21 additions & 0 deletions
@@ -559,6 +559,13 @@ def getOrCreate(self) -> "DataprocSparkSession":
             session = self._get_exiting_active_session()
             if session is None:
                 session = self.__create()
+
+            # Register this session as the instantiated SparkSession for compatibility
+            # with tools and libraries that expect SparkSession._instantiatedSession
+            from pyspark.sql import SparkSession as PySparkSQLSession
+
+            PySparkSQLSession._instantiatedSession = session
+
             return session
 
     def _handle_custom_session_id(self):
@@ -1162,6 +1169,20 @@ def stop(self) -> None:
                 )
 
             self._remove_stopped_session_from_file()
+
+            # Clean up SparkSession._instantiatedSession if it points to this session
+            try:
+                from pyspark.sql import SparkSession as PySparkSQLSession
+
+                if PySparkSQLSession._instantiatedSession is self:
+                    PySparkSQLSession._instantiatedSession = None
+                    logger.debug(
+                        "Cleared SparkSession._instantiatedSession reference"
+                    )
+            except (ImportError, AttributeError):
+                # PySpark not available or _instantiatedSession doesn't exist
+                pass
+
             DataprocSparkSession._active_s8s_session_uuid = None
             DataprocSparkSession._active_s8s_session_id = None
             DataprocSparkSession._active_session_uses_custom_id = False
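
The registration added in `getOrCreate()` means consumers that look up the plain PySpark session can now find the Dataproc-backed one. An illustrative sketch of that lookup; this is not sparksql-magic's actual source, and the helper name is hypothetical:

```python
# Illustrative only: how a tool that expects SparkSession._instantiatedSession
# can discover the Dataproc-backed session after getOrCreate() has run.
from pyspark.sql import SparkSession


def find_active_session():
    # Set by the registration added above; cleared again in stop().
    session = SparkSession._instantiatedSession
    if session is None:
        raise RuntimeError("No active Spark session found")
    return session
```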

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
@@ -1,9 +1,11 @@
 google-api-core>=2.19
 google-cloud-dataproc>=5.18
 ipython~=9.1
+ipywidgets>=8.0.0
 packaging>=20.0
 pyink~=24.0
 pyspark[connect]~=4.0.0
 setuptools>=72.0
+sparksql-magic>=0.0.3
 tqdm>=4.67
 websockets>=14.0

tests/integration/test_session.py

Lines changed: 86 additions & 2 deletions
@@ -79,7 +79,7 @@ def os_environment(auth_type, image_version, test_project, test_region):
     )
     os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = auth_type
     if auth_type == "END_USER_CREDENTIALS":
-        os.environ.pop("DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT")
+        os.environ.pop("DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT", None)
     # Add SSL certificate fix
     os.environ["SSL_CERT_FILE"] = certifi.where()
     os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
@@ -113,7 +113,11 @@ def session_template_controller_client(test_client_options):
 
 @pytest.fixture
 def connect_session(test_project, test_region, os_environment):
-    return DataprocSparkSession.builder.getOrCreate()
+    return (
+        DataprocSparkSession.builder.projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )
 
 
 @pytest.fixture
@@ -537,3 +541,83 @@ def test_session_id_validation_in_integration(
 
     # Should not raise an exception
     assert builder._custom_session_id == valid_id
+
+
+@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
+def test_sparksql_magic_library_available(connect_session):
+    """Test that the sparksql-magic library can be imported and loaded."""
+    pytest.importorskip(
+        "IPython", reason="IPython not available (install with magic extra)"
+    )
+    pytest.importorskip(
+        "sparksql_magic",
+        reason="sparksql-magic not available (install with magic extra)",
+    )
+
+    from IPython.terminal.interactiveshell import TerminalInteractiveShell
+
+    # Create a real IPython shell
+    shell = TerminalInteractiveShell.instance()
+    shell.user_ns = {"spark": connect_session}
+
+    # Test that sparksql_magic can be loaded (this verifies the dependency works)
+    try:
+        shell.run_line_magic("load_ext", "sparksql_magic")
+        magic_loaded = True
+    except Exception as e:
+        magic_loaded = False
+        print(f"Failed to load sparksql_magic: {e}")
+
+    assert magic_loaded, "sparksql_magic should be available as a dependency"
+
+    # Test that DataprocSparkSession can execute SQL (ensuring basic compatibility)
+    result = connect_session.sql("SELECT 'integration_test' as test_column")
+    data = result.collect()
+    assert len(data) == 1
+    assert data[0]["test_column"] == "integration_test"
+
+
+@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
+def test_sparksql_magic_with_dataproc_session(connect_session):
+    """Test that sparksql-magic works with a registered DataprocSparkSession."""
+    pytest.importorskip(
+        "IPython", reason="IPython not available (install with magic extra)"
+    )
+    pytest.importorskip(
+        "sparksql_magic",
+        reason="sparksql-magic not available (install with magic extra)",
+    )
+
+    from IPython.terminal.interactiveshell import TerminalInteractiveShell
+
+    # Create a real IPython shell (DataprocSparkSession is already registered globally)
+    shell = TerminalInteractiveShell.instance()
+
+    # Load the sparksql_magic extension
+    shell.run_line_magic("load_ext", "sparksql_magic")
+
+    # Run the sparksql cell magic and capture the result into result_df
+    shell.run_cell_magic(
+        "sparksql",
+        "result_df",
+        """
+        SELECT
+            10 * 5 as multiplication,
+            SQRT(16) as square_root,
+            CONCAT('Dataproc', '-', 'Spark') as joined_string
+        """,
+    )
+
+    # Verify the result is captured in the namespace
+    assert "result_df" in shell.user_ns
+    df = shell.user_ns["result_df"]
+    assert df is not None
+
+    # Verify the computed values
+    data = df.collect()
+    assert len(data) == 1
+    row = data[0]
+
+    assert row["multiplication"] == 50
+    assert row["square_root"] == 4.0
+    assert row["joined_string"] == "Dataproc-Spark"
