Skip to content

Commit a30a3fd

Browse files
[SPARK-49530][PYTHON] Support pie subplots in pyspark plotting
### What changes were proposed in this pull request? Support pie subplots in pyspark plotting. ### Why are the changes needed? API parity with pandas.DataFrame.plot.pie, see [here](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.plot.pie.html) ### Does this PR introduce _any_ user-facing change? Pie subplots are supported as shown below: ```py >>> from datetime import datetime >>> data = [ ... (3, 5, 20, datetime(2018, 1, 31)), ... (2, 5, 42, datetime(2018, 2, 28)), ... (3, 6, 28, datetime(2018, 3, 31)), ... (9, 12, 62, datetime(2018, 4, 30)), ... ] >>> columns = ["sales", "signups", "visits", "date"] >>> df = spark.createDataFrame(data, columns) >>> fig = df.plot(kind="pie", x="date", subplots=True) >>> fig.show() ``` ![newplot (2)](https://github.com/user-attachments/assets/2b019c6a-82da-4c12-b1ff-096786801f56) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49268 from xinrong-meng/pie_subplot. Lead-authored-by: Xinrong Meng <[email protected]> Co-authored-by: Hyukjin Kwon <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 7cd5c4a commit a30a3fd

File tree

4 files changed

+68
-23
lines changed

4 files changed

+68
-23
lines changed

python/pyspark/errors/error-conditions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,11 @@
11031103
"Function `<func_name>` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
11041104
]
11051105
},
1106+
"UNSUPPORTED_PIE_PLOT_PARAM": {
1107+
"message": [
1108+
"Pie plot requires either a `y` column or `subplots=True`."
1109+
]
1110+
},
11061111
"UNSUPPORTED_PLOT_BACKEND": {
11071112
"message": [
11081113
"`<backend>` is not supported, it should be one of the values from <supported_backends>"

python/pyspark/sql/plot/core.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919

2020
from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
2121
from types import ModuleType
22-
from pyspark.errors import PySparkTypeError, PySparkValueError
22+
from pyspark.errors import PySparkValueError
2323
from pyspark.sql import Column, functions as F
2424
from pyspark.sql.internal import InternalFunction as SF
2525
from pyspark.sql.pandas.utils import require_minimum_pandas_version
26-
from pyspark.sql.types import NumericType
2726
from pyspark.sql.utils import NumpyHelper, require_minimum_plotly_version
2827

2928
if TYPE_CHECKING:
@@ -295,7 +294,7 @@ def area(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
295294
"""
296295
return self(kind="area", x=x, y=y, **kwargs)
297296

298-
def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
297+
def pie(self, x: str, y: Optional[str], **kwargs: Any) -> "Figure":
299298
"""
300299
Generate a pie plot.
301300
@@ -306,8 +305,8 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
306305
----------
307306
x : str
308307
Name of column to be used as the category labels for the pie plot.
309-
y : str
310-
Name of the column to plot.
308+
y : str, optional
309+
Name of the column to plot. If not provided, `subplots=True` must be passed at `kwargs`.
311310
**kwargs
312311
Additional keyword arguments.
313312
@@ -327,19 +326,8 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
327326
>>> columns = ["sales", "signups", "visits", "date"]
328327
>>> df = spark.createDataFrame(data, columns)
329328
>>> df.plot.pie(x='date', y='sales') # doctest: +SKIP
329+
>>> df.plot.pie(x='date', subplots=True) # doctest: +SKIP
330330
"""
331-
schema = self.data.schema
332-
333-
# Check if 'y' is a numerical column
334-
y_field = schema[y] if y in schema.names else None
335-
if y_field is None or not isinstance(y_field.dataType, NumericType):
336-
raise PySparkTypeError(
337-
errorClass="PLOT_NOT_NUMERIC_COLUMN_ARGUMENT",
338-
messageParameters={
339-
"arg_name": "y",
340-
"arg_type": str(y_field.dataType.__class__.__name__) if y_field else "None",
341-
},
342-
)
343331
return self(kind="pie", x=x, y=y, **kwargs)
344332

345333
def box(self, column: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> "Figure":

python/pyspark/sql/plot/plotly.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,34 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
4848

4949

5050
def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
51-
# TODO(SPARK-49530): Support pie subplots with plotly backend
5251
from plotly import express
5352

5453
pdf = PySparkPlotAccessor.plot_data_map["pie"](data)
5554
x = kwargs.pop("x", None)
5655
y = kwargs.pop("y", None)
57-
fig = express.pie(pdf, values=y, names=x, **kwargs)
56+
subplots = kwargs.pop("subplots", False)
57+
if y is None and not subplots:
58+
raise PySparkValueError(errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={})
59+
60+
numeric_ys = process_column_param(y, data)
61+
62+
if subplots:
63+
# One pie chart per numeric column
64+
from plotly.subplots import make_subplots
65+
66+
fig = make_subplots(
67+
rows=1,
68+
cols=len(numeric_ys),
69+
# To accommodate domain-based trace - pie chart
70+
specs=[[{"type": "domain"}] * len(numeric_ys)],
71+
)
72+
for i, y_col in enumerate(numeric_ys):
73+
subplot_fig = express.pie(pdf, values=y_col, names=x, **kwargs)
74+
fig.add_trace(
75+
subplot_fig.data[0], row=1, col=i + 1
76+
) # A single pie chart has only one trace
77+
else:
78+
fig = express.pie(pdf, values=numeric_ys[0], names=x, **kwargs)
5879

5980
return fig
6081

python/pyspark/sql/tests/plot/test_frame_plot_plotly.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,29 +301,60 @@ def test_area_plot(self):
301301
self._check_fig_data(fig["data"][2], **expected_fig_data)
302302

303303
def test_pie_plot(self):
304+
# single column as 'y'
304305
fig = self.sdf3.plot(kind="pie", x="date", y="sales")
305306
expected_x = [
306307
datetime(2018, 1, 31, 0, 0),
307308
datetime(2018, 2, 28, 0, 0),
308309
datetime(2018, 3, 31, 0, 0),
309310
datetime(2018, 4, 30, 0, 0),
310311
]
311-
expected_fig_data = {
312+
expected_fig_data_sales = {
312313
"name": "",
313314
"labels": expected_x,
314315
"values": [3, 2, 3, 9],
315316
"type": "pie",
316317
}
317-
self._check_fig_data(fig["data"][0], **expected_fig_data)
318+
self._check_fig_data(fig["data"][0], **expected_fig_data_sales)
319+
320+
# all numeric columns as 'y'
321+
expected_fig_data_signups = {
322+
"name": "",
323+
"labels": expected_x,
324+
"values": [5, 5, 6, 12],
325+
"type": "pie",
326+
}
327+
expected_fig_data_visits = {
328+
"name": "",
329+
"labels": expected_x,
330+
"values": [20, 42, 28, 62],
331+
"type": "pie",
332+
}
333+
fig = self.sdf3.plot(kind="pie", x="date", subplots=True)
334+
self._check_fig_data(fig["data"][0], **expected_fig_data_sales)
335+
self._check_fig_data(fig["data"][1], **expected_fig_data_signups)
336+
self._check_fig_data(fig["data"][2], **expected_fig_data_visits)
337+
338+
# not specify subplots
339+
with self.assertRaises(PySparkValueError) as pe:
340+
self.sdf3.plot(kind="pie", x="date")
341+
342+
self.check_error(
343+
exception=pe.exception, errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={}
344+
)
318345

319346
# y is not a numerical column
320347
with self.assertRaises(PySparkTypeError) as pe:
321348
self.sdf.plot.pie(x="int_val", y="category")
322349

323350
self.check_error(
324351
exception=pe.exception,
325-
errorClass="PLOT_NOT_NUMERIC_COLUMN_ARGUMENT",
326-
messageParameters={"arg_name": "y", "arg_type": "StringType"},
352+
errorClass="PLOT_INVALID_TYPE_COLUMN",
353+
messageParameters={
354+
"col_name": "category",
355+
"valid_types": "NumericType",
356+
"col_type": "StringType",
357+
},
327358
)
328359

329360
def test_box_plot(self):

0 commit comments

Comments
 (0)