diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a006a662cb..6793abbe15 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -685,6 +685,23 @@ def _repr_html_(self) -> str: except ImportError: return preview._repr_html_() + @DataframePublicAPI + def _repr_mimebundle_( + self, include: Iterable[str] | None = None, exclude: Iterable[str] | None = None + ) -> dict[str, str]: + include_set = set(include) if include is not None else None + exclude_set = set(exclude) if exclude is not None else set() + + mimebundle: dict[str, str] = {} + + if (include_set is None or "text/plain" in include_set) and "text/plain" not in exclude_set: + mimebundle["text/plain"] = self.__repr__() + + if (include_set is None or "text/html" in include_set) and "text/html" not in exclude_set: + mimebundle["text/html"] = self._repr_html_() + + return mimebundle + ### # Creation methods ### diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index f657420585..2c2c8d2fee 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -275,6 +275,27 @@ def test_repr_with_html_string(): ) +@pytest.mark.parametrize( + "kwargs,expected_keys", + [ + ({}, {"text/plain", "text/html"}), + ({"include": {"text/plain"}}, {"text/plain"}), + ({"exclude": {"text/html"}}, {"text/plain"}), + ({"include": {"text/plain"}, "exclude": {"text/plain"}}, set()), + ], +) +def test_repr_mimebundle(make_df, kwargs, expected_keys): + df = make_df({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + + bundle = df._repr_mimebundle_(**kwargs) + + assert set(bundle.keys()) == expected_keys + if "text/plain" in bundle: + assert bundle["text/plain"] == df.__repr__() + if "text/html" in bundle: + assert bundle["text/html"] == df._repr_html_() + + class MyObj: def __repr__(self) -> str: return "myobj-custom-repr"