Skip to content

Commit 488c3f6

Browse files
committed
[SPARK-49776][PYTHON][CONNECT] Support pie plots
### What changes were proposed in this pull request? Support area plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. Area plots are supported as shown below. ```py >>> from datetime import datetime >>> data = [ ... (3, 5, 20, datetime(2018, 1, 31)), ... (2, 5, 42, datetime(2018, 2, 28)), ... (3, 6, 28, datetime(2018, 3, 31)), ... (9, 12, 62, datetime(2018, 4, 30))] >>> columns = ["sales", "signups", "visits", "date"] >>> df = spark.createDataFrame(data, columns) >>> fig = df.plot(kind="pie", x="date", y="sales") # df.plot(kind="pie", x="date", y="sales") >>> fig.show() ``` ![newplot (8)](https://github.com/user-attachments/assets/c4078bb7-4d84-4607-bcd7-bdd6fbbf8e28) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48256 from xinrong-meng/plot_pie. Authored-by: Xinrong Meng <[email protected]> Signed-off-by: Xinrong Meng <[email protected]>
1 parent 09b7aa6 commit 488c3f6

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed

python/pyspark/errors/error-conditions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,11 @@
812812
"Pipe function `<func_name>` exited with error code <error_code>."
813813
]
814814
},
815+
"PLOT_NOT_NUMERIC_COLUMN": {
816+
"message": [
817+
"Argument <arg_name> must be a numerical column for plotting, got <arg_type>."
818+
]
819+
},
815820
"PYTHON_HASH_SEED_NOT_SET": {
816821
"message": [
817822
"Randomness of hash of string should be disabled via PYTHONHASHSEED."

python/pyspark/sql/plot/core.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
from typing import Any, TYPE_CHECKING, Optional, Union
1919
from types import ModuleType
20-
from pyspark.errors import PySparkRuntimeError, PySparkValueError
20+
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
21+
from pyspark.sql.types import NumericType
2122
from pyspark.sql.utils import require_minimum_plotly_version
2223

2324

@@ -97,6 +98,7 @@ class PySparkPlotAccessor:
9798
"bar": PySparkTopNPlotBase().get_top_n,
9899
"barh": PySparkTopNPlotBase().get_top_n,
99100
"line": PySparkSampledPlotBase().get_sampled,
101+
"pie": PySparkTopNPlotBase().get_top_n,
100102
"scatter": PySparkSampledPlotBase().get_sampled,
101103
}
102104
_backends = {} # type: ignore[var-annotated]
@@ -299,3 +301,40 @@ def area(self, x: str, y: str, **kwargs: Any) -> "Figure":
299301
>>> df.plot.area(x='date', y=['sales', 'signups', 'visits']) # doctest: +SKIP
300302
"""
301303
return self(kind="area", x=x, y=y, **kwargs)
304+
305+
def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
306+
"""
307+
Generate a pie plot.
308+
309+
A pie plot is a proportional representation of the numerical data in a
310+
column.
311+
312+
Parameters
313+
----------
314+
x : str
315+
Name of column to be used as the category labels for the pie plot.
316+
y : str
317+
Name of the column to plot.
318+
**kwargs
319+
Additional keyword arguments.
320+
321+
Returns
322+
-------
323+
:class:`plotly.graph_objs.Figure`
324+
325+
Examples
326+
--------
327+
"""
328+
schema = self.data.schema
329+
330+
# Check if 'y' is a numerical column
331+
y_field = schema[y] if y in schema.names else None
332+
if y_field is None or not isinstance(y_field.dataType, NumericType):
333+
raise PySparkTypeError(
334+
errorClass="PLOT_NOT_NUMERIC_COLUMN",
335+
messageParameters={
336+
"arg_name": "y",
337+
"arg_type": str(y_field.dataType) if y_field else "None",
338+
},
339+
)
340+
return self(kind="pie", x=x, y=y, **kwargs)

python/pyspark/sql/plot/plotly.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,19 @@
2727
def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
2828
import plotly
2929

30+
if kind == "pie":
31+
return plot_pie(data, **kwargs)
32+
3033
return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
34+
35+
36+
def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
37+
# TODO(SPARK-49530): Support pie subplots with plotly backend
38+
from plotly import express
39+
40+
pdf = PySparkPlotAccessor.plot_data_map["pie"](data)
41+
x = kwargs.pop("x", None)
42+
y = kwargs.pop("y", None)
43+
fig = express.pie(pdf, values=y, names=x, **kwargs)
44+
45+
return fig

python/pyspark/sql/tests/plot/test_frame_plot_plotly.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from datetime import datetime
2020

2121
import pyspark.sql.plot # noqa: F401
22+
from pyspark.errors import PySparkTypeError
2223
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
2324

2425

@@ -64,6 +65,11 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=
6465
self.assertEqual(fig_data["type"], "scatter")
6566
self.assertEqual(fig_data["orientation"], "v")
6667
self.assertEqual(fig_data["mode"], "lines")
68+
elif kind == "pie":
69+
self.assertEqual(fig_data["type"], "pie")
70+
self.assertEqual(list(fig_data["labels"]), expected_x)
71+
self.assertEqual(list(fig_data["values"]), expected_y)
72+
return
6773

6874
self.assertEqual(fig_data["xaxis"], "x")
6975
self.assertEqual(list(fig_data["x"]), expected_x)
@@ -133,6 +139,25 @@ def test_area_plot(self):
133139
self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups")
134140
self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits")
135141

142+
def test_pie_plot(self):
143+
fig = self.sdf3.plot(kind="pie", x="date", y="sales")
144+
expected_x = [
145+
datetime(2018, 1, 31, 0, 0),
146+
datetime(2018, 2, 28, 0, 0),
147+
datetime(2018, 3, 31, 0, 0),
148+
datetime(2018, 4, 30, 0, 0),
149+
]
150+
self._check_fig_data("pie", fig["data"][0], expected_x, [3, 2, 3, 9])
151+
152+
# y is not a numerical column
153+
with self.assertRaises(PySparkTypeError) as pe:
154+
self.sdf.plot.pie(x="int_val", y="category")
155+
self.check_error(
156+
exception=pe.exception,
157+
errorClass="PLOT_NOT_NUMERIC_COLUMN",
158+
messageParameters={"arg_name": "y", "arg_type": "StringType()"},
159+
)
160+
136161

137162
class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
138163
pass

0 commit comments

Comments
 (0)