Skip to content

Commit 0bd8507

Browse files
committed
fix: use PyCapsule Interface instead of Dataframe Interchange Protocol
1 parent b4e5f8d commit 0bd8507

File tree

8 files changed

+66
-66
lines changed

8 files changed

+66
-66
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dev = [
4242
"mypy",
4343
"pandas-stubs",
4444
"pre-commit",
45+
"pyarrow",
4546
"flit",
4647
]
4748
docs = [

seaborn/_core/data.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from collections.abc import Mapping, Sized
77
from typing import cast
8-
import warnings
98

109
import pandas as pd
1110
from pandas import DataFrame
@@ -269,9 +268,9 @@ def _assign_variables(
269268

270269
def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
271270
"""Convert the data source object to a common union representation."""
272-
if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"):
271+
if isinstance(data, pd.DataFrame) or hasattr(data, "__arrow_c_stream__"):
273272
# Check for pd.DataFrame inheritance could be removed once
274-
# minimal pandas version supports dataframe interchange (1.5.0).
273+
# minimal pandas version supports PyCapsule Interface (2.2).
275274
data = convert_dataframe_to_pandas(data)
276275
elif data is not None and not isinstance(data, Mapping):
277276
err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
@@ -285,35 +284,32 @@ def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
285284
if isinstance(data, pd.DataFrame):
286285
return data
287286

288-
if not hasattr(pd.api, "interchange"):
289-
msg = (
290-
"Support for non-pandas DataFrame objects requires a version of pandas "
291-
"that implements the DataFrame interchange protocol. Please upgrade "
292-
"your pandas version or coerce your data to pandas before passing "
293-
"it to seaborn."
294-
)
295-
raise TypeError(msg)
296-
297-
if _version_predates(pd, "2.0.2"):
298-
msg = (
299-
"DataFrame interchange with pandas<2.0.2 has some known issues. "
300-
f"You are using pandas {pd.__version__}. "
301-
"Continuing, but it is recommended to carefully inspect the results and to "
302-
"consider upgrading."
303-
)
304-
warnings.warn(msg, stacklevel=2)
305-
306-
try:
307-
# This is going to convert all columns in the input dataframe, even though
308-
# we may only need one or two of them. It would be more efficient to select
309-
# the columns that are going to be used in the plot prior to interchange.
310-
# Solving that in general is a hard problem, especially with the objects
311-
# interface where variables passed in Plot() may only be referenced later
312-
# in Plot.add(). But noting here in case this seems to be a bottleneck.
313-
return pd.api.interchange.from_dataframe(data)
314-
except Exception as err:
315-
msg = (
316-
"Encountered an exception when converting data source "
317-
"to a pandas DataFrame. See traceback above for details."
318-
)
319-
raise RuntimeError(msg) from err
287+
if hasattr(data, '__arrow_c_stream__'):
288+
try:
289+
import pyarrow
290+
except ImportError as err:
291+
msg = "PyArrow is required for non-pandas Dataframe support."
292+
raise RuntimeError(msg) from err
293+
if _version_predates(pyarrow, '14.0.0'):
294+
msg = "PyArrow>=14.0.0 is required for non-pandas Dataframe support."
295+
raise RuntimeError(msg)
296+
try:
297+
# This is going to convert all columns in the input dataframe, even though
298+
# we may only need one or two of them. It would be more efficient to select
299+
# the columns that are going to be used in the plot prior to interchange.
300+
# Solving that in general is a hard problem, especially with the objects
301+
# interface where variables passed in Plot() may only be referenced later
302+
# in Plot.add(). But noting here in case this seems to be a bottleneck.
303+
return pyarrow.table(data).to_pandas()
304+
except Exception as err:
305+
msg = (
306+
"Encountered an exception when converting data source "
307+
"to a pandas DataFrame. See traceback above for details."
308+
)
309+
raise RuntimeError(msg) from err
310+
311+
msg = (
312+
"Expected object which implements '__arrow_c_stream__' from the "
313+
f"PyCapsule Interface, got: {type(data)}"
314+
)
315+
raise TypeError(msg)

seaborn/_core/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def _resolve_positionals(
349349

350350
if (
351351
isinstance(args[0], (abc.Mapping, pd.DataFrame))
352-
or hasattr(args[0], "__dataframe__")
352+
or hasattr(args[0], "__arrow_c_stream__")
353353
):
354354
if data is not None:
355355
raise TypeError("`data` given by both name and position.")

seaborn/_core/typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
VariableSpec = Union[ColumnName, Vector, None]
1818
VariableSpecList = Union[List[VariableSpec], Index, None]
1919

20-
# A DataSource can be an object implementing __dataframe__, or a Mapping
20+
# A DataSource can be an object implementing __arrow_c_stream__, or a Mapping
2121
# (and is optional in all contexts where it is used).
22-
# I don't think there's an abc for "has __dataframe__", so we type as object
22+
# I don't think there's an abc for "has __arrow_c_stream__", so we type as object
2323
# but keep the (slightly odd) Union alias for better user-facing annotations.
2424
DataSource = Union[object, Mapping, None]
2525

tests/_core/test_data.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import numpy as np
33
import pandas as pd
4+
from seaborn.external.version import Version
45

56
import pytest
67
from numpy.testing import assert_array_equal
@@ -404,11 +405,11 @@ def test_bad_type(self, flat_list):
404405
with pytest.raises(TypeError, match=err):
405406
PlotData(flat_list, {})
406407

407-
@pytest.mark.skipif(
408-
condition=not hasattr(pd.api, "interchange"),
409-
reason="Tests behavior assuming support for dataframe interchange"
410-
)
411408
def test_data_interchange(self, mock_long_df, long_df):
409+
pytest.importorskip(
410+
'pyarrow', '14.0',
411+
reason="Tests behavior assuming support for PyCapsule Interface"
412+
)
412413

413414
variables = {"x": "x", "y": "z", "color": "a"}
414415
p = PlotData(mock_long_df, variables)
@@ -419,21 +420,22 @@ def test_data_interchange(self, mock_long_df, long_df):
419420
for var, col in variables.items():
420421
assert_vector_equal(p.frame[var], long_df[col])
421422

422-
@pytest.mark.skipif(
423-
condition=not hasattr(pd.api, "interchange"),
424-
reason="Tests behavior assuming support for dataframe interchange"
425-
)
426423
def test_data_interchange_failure(self, mock_long_df):
424+
pytest.importorskip(
425+
'pyarrow', '14.0',
426+
reason="Tests behavior assuming support for PyCapsule Interface"
427+
)
427428

428-
mock_long_df._data = None # Break __dataframe__()
429+
mock_long_df.__arrow_c_stream__ = lambda _x: 1 / 0 # Break __arrow_c_stream__()
429430
with pytest.raises(RuntimeError, match="Encountered an exception"):
430431
PlotData(mock_long_df, {"x": "x"})
431432

432-
@pytest.mark.skipif(
433-
condition=hasattr(pd.api, "interchange"),
434-
reason="Tests graceful failure without support for dataframe interchange"
435-
)
436433
def test_data_interchange_support_test(self, mock_long_df):
434+
pyarrow = pytest.importorskip('pyarrow')
435+
if Version(pyarrow.__version__) >= Version('14.0.0'):
436+
pytest.skip(
437+
reason="Tests graceful failure without support for PyCapsule Interface"
438+
)
437439

438440
with pytest.raises(TypeError, match="Support for non-pandas DataFrame"):
439441
PlotData(mock_long_df, {"x": "x"})

tests/_core/test_plot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ def test_positional_x(self, long_df):
170170
assert p._data.source_data is None
171171
assert list(p._data.source_vars) == ["x"]
172172

173-
@pytest.mark.skipif(
174-
condition=not hasattr(pd.api, "interchange"),
175-
reason="Tests behavior assuming support for dataframe interchange"
176-
)
177173
def test_positional_interchangeable_dataframe(self, mock_long_df, long_df):
174+
pytest.importorskip(
175+
'pyarrow', '14.0',
176+
reason="Tests behavior assuming support for PyCapsule Interface"
177+
)
178178

179179
p = Plot(mock_long_df, x="x")
180180
assert_frame_equal(p._data.source_data, long_df)

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,12 @@ class MockInterchangeableDataFrame:
188188
def __init__(self, data):
189189
self._data = data
190190

191-
def __dataframe__(self, *args, **kwargs):
192-
return self._data.__dataframe__(*args, **kwargs)
191+
def __arrow_c_stream__(self, *args, **kwargs):
192+
return self._data.__arrow_c_stream__()
193193

194194

195195
@pytest.fixture
196196
def mock_long_df(long_df):
197+
import pyarrow
197198

198-
return MockInterchangeableDataFrame(long_df)
199+
return MockInterchangeableDataFrame(pyarrow.Table.from_pandas(long_df))

tests/test_axisgrid.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -708,11 +708,11 @@ def test_tick_params(self):
708708
assert mpl.colors.same_color(tick.tick2line.get_color(), color)
709709
assert tick.get_pad() == pad
710710

711-
@pytest.mark.skipif(
712-
condition=not hasattr(pd.api, "interchange"),
713-
reason="Tests behavior assuming support for dataframe interchange"
714-
)
715711
def test_data_interchange(self, mock_long_df, long_df):
712+
pytest.importorskip(
713+
'pyarrow', '14.0',
714+
reason="Tests behavior assuming support for PyCapsule Interface"
715+
)
716716

717717
g = ag.FacetGrid(mock_long_df, col="a", row="b")
718718
g.map(scatterplot, "x", "y")
@@ -1477,11 +1477,11 @@ def test_tick_params(self):
14771477
assert mpl.colors.same_color(tick.tick2line.get_color(), color)
14781478
assert tick.get_pad() == pad
14791479

1480-
@pytest.mark.skipif(
1481-
condition=not hasattr(pd.api, "interchange"),
1482-
reason="Tests behavior assuming support for dataframe interchange"
1483-
)
14841480
def test_data_interchange(self, mock_long_df, long_df):
1481+
pytest.importorskip(
1482+
'pyarrow', '14.0',
1483+
reason="Tests behavior assuming support for PyCapsule Interface"
1484+
)
14851485

14861486
g = ag.PairGrid(mock_long_df, vars=["x", "y", "z"], hue="a")
14871487
g.map(scatterplot)

0 commit comments

Comments
 (0)