Skip to content

Commit aa6132d

Browse files
committed
fix: use PyCapsule Interface instead of Dataframe Interchange Protocol
1 parent b4e5f8d commit aa6132d

File tree

4 files changed

+32
-37
lines changed

4 files changed

+32
-37
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dev = [
4242
"mypy",
4343
"pandas-stubs",
4444
"pre-commit",
45+
"pyarrow",
4546
"flit",
4647
]
4748
docs = [

seaborn/_core/data.py

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,9 @@ def _assign_variables(
269269

270270
def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
271271
"""Convert the data source object to a common union representation."""
272-
if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"):
272+
if isinstance(data, pd.DataFrame) or hasattr(data, "__arrow_c_stream__"):
273273
# Check for pd.DataFrame inheritance could be removed once
274-
# minimal pandas version supports dataframe interchange (1.5.0).
274+
# minimal pandas version supports PyCapsule Interface (2.2).
275275
data = convert_dataframe_to_pandas(data)
276276
elif data is not None and not isinstance(data, Mapping):
277277
err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
@@ -285,35 +285,29 @@ def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
285285
if isinstance(data, pd.DataFrame):
286286
return data
287287

288-
if not hasattr(pd.api, "interchange"):
289-
msg = (
290-
"Support for non-pandas DataFrame objects requires a version of pandas "
291-
"that implements the DataFrame interchange protocol. Please upgrade "
292-
"your pandas version or coerce your data to pandas before passing "
293-
"it to seaborn."
294-
)
295-
raise TypeError(msg)
296-
297-
if _version_predates(pd, "2.0.2"):
298-
msg = (
299-
"DataFrame interchange with pandas<2.0.2 has some known issues. "
300-
f"You are using pandas {pd.__version__}. "
301-
"Continuing, but it is recommended to carefully inspect the results and to "
302-
"consider upgrading."
303-
)
304-
warnings.warn(msg, stacklevel=2)
305-
306-
try:
307-
# This is going to convert all columns in the input dataframe, even though
308-
# we may only need one or two of them. It would be more efficient to select
309-
# the columns that are going to be used in the plot prior to interchange.
310-
# Solving that in general is a hard problem, especially with the objects
311-
# interface where variables passed in Plot() may only be referenced later
312-
# in Plot.add(). But noting here in case this seems to be a bottleneck.
313-
return pd.api.interchange.from_dataframe(data)
314-
except Exception as err:
315-
msg = (
316-
"Encountered an exception when converting data source "
317-
"to a pandas DataFrame. See traceback above for details."
318-
)
319-
raise RuntimeError(msg) from err
288+
if hasattr(data, '__arrow_c_stream__'):
289+
try:
290+
import pyarrow
291+
except ImportError as err:
292+
msg = "PyArrow is required for non-pandas Dataframe support."
293+
raise RuntimeError(msg) from err
294+
if _version_predates(pyarrow, '14.0.0'):
295+
msg = "PyArrow>=14.0.0 is required for non-pandas Dataframe support."
296+
raise RuntimeError(msg)
297+
try:
298+
# This is going to convert all columns in the input dataframe, even though
299+
# we may only need one or two of them. It would be more efficient to select
300+
# the columns that are going to be used in the plot prior to interchange.
301+
# Solving that in general is a hard problem, especially with the objects
302+
# interface where variables passed in Plot() may only be referenced later
303+
# in Plot.add(). But noting here in case this seems to be a bottleneck.
304+
return pyarrow.table(data).to_pandas()
305+
except Exception as err:
306+
msg = (
307+
"Encountered an exception when converting data source "
308+
"to a pandas DataFrame. See traceback above for details."
309+
)
310+
raise RuntimeError(msg) from err
311+
312+
msg = f"Expected object which implements '__arrow_c_stream__' from the PyCapsule Interface, got: {type(data)}"
313+
raise TypeError(msg)

tests/_core/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def test_data_interchange(self, mock_long_df, long_df):
425425
)
426426
def test_data_interchange_failure(self, mock_long_df):
427427

428-
mock_long_df._data = None # Break __dataframe__()
428+
mock_long_df.__arrow_c_stream__ = lambda x: 1/0 # Break __arrow_c_stream__()
429429
with pytest.raises(RuntimeError, match="Encountered an exception"):
430430
PlotData(mock_long_df, {"x": "x"})
431431

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ class MockInterchangeableDataFrame:
188188
def __init__(self, data):
189189
self._data = data
190190

191-
def __dataframe__(self, *args, **kwargs):
192-
return self._data.__dataframe__(*args, **kwargs)
191+
def __arrow_c_stream__(self, *args, **kwargs):
192+
return self._data.__arrow_c_stream__()
193193

194194

195195
@pytest.fixture

0 commit comments

Comments
 (0)