-
Notifications
You must be signed in to change notification settings - Fork 170
fix(typing): Overhaul @overloads in nw.from_native
#3125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Will close #1928
Absolutely not doing this with another class
wayyy more common in this case
The `TypedDict`s are clear enough
narwhals/typing.py
Outdated
| class NativeDataFrame(Sized, NativeFrame, Protocol): | ||
| def drop(self, *args: Any, **kwargs: Any) -> Any: ... | ||
|
|
||
| class NativeLazyFrame(NativeFrame, Protocol): | ||
| def explain(self, *args: Any, **kwargs: Any) -> Any: ... | ||
|
|
||
| # Needs to have something `NativeDataFrame` doesn't? | ||
| class NativeSeries(Sized, Iterable[Any], Protocol): | ||
| def filter(self, *args: Any, **kwargs: Any) -> Any: ... | ||
| # `pd.DataFrame` has this - the others don't | ||
| def value_counts(self, *args: Any, **kwargs: Any) -> Any: ... | ||
| # `pl.DataFrame` has this - the others don't | ||
| def unique(self, *args: Any, **kwargs: Any) -> Any: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue 1
NativeDataFrame and NativeSeries could overlap.
Although we don't define NativeDataFrame.filter - the classes that match it do:
That was leading to the strange @overload order and parameters changing the return type.
But now if we started with NativeSeries, then both series_only and allow_series will preserve that detail:
Prove it bro
narwhals/tests/translate/from_native_test.py
Lines 621 to 723 in a0095b1
| def test_from_native_series_exhaustive() -> None: # noqa: PLR0914, PLR0915 | |
| pytest.importorskip("polars") | |
| pytest.importorskip("pandas") | |
| pytest.importorskip("pyarrow") | |
| pytest.importorskip("typing_extensions") | |
| import pandas as pd | |
| import polars as pl | |
| import pyarrow as pa | |
| from typing_extensions import assert_type | |
| pl_ser = pl.Series([1, 2, 3]) | |
| pd_ser = cast("pd.Series[Any]", pd.Series([1, 2, 3])) | |
| pa_ser = cast("pa.ChunkedArray[Any]", pa.chunked_array([pa.array([1])])) # type: ignore[redundant-cast] | |
| pl_1 = nw.from_native(pl_ser, series_only=True) | |
| pl_2 = nw.from_native(pl_ser, allow_series=True) | |
| pl_3 = nw.from_native(pl_ser, eager_only=True, series_only=True) | |
| pl_4 = nw.from_native(pl_ser, eager_only=True, series_only=True, allow_series=True) | |
| pl_5 = nw.from_native(pl_ser, eager_only=True, allow_series=True) | |
| pl_6 = nw.from_native(pl_ser, series_only=True, allow_series=True) | |
| pl_7 = nw.from_native(pl_ser, series_only=True, pass_through=True) | |
| pl_8 = nw.from_native(pl_ser, allow_series=True, pass_through=True) | |
| pl_9 = nw.from_native(pl_ser, eager_only=True, series_only=True, pass_through=True) | |
| pl_10 = nw.from_native( | |
| pl_ser, eager_only=True, series_only=True, allow_series=True, pass_through=True | |
| ) | |
| pl_11 = nw.from_native(pl_ser, eager_only=True, allow_series=True, pass_through=True) | |
| pl_12 = nw.from_native(pl_ser, series_only=True, allow_series=True, pass_through=True) | |
| pls = pl_1, pl_2, pl_3, pl_4, pl_5, pl_6, pl_7, pl_8, pl_9, pl_10, pl_11, pl_12 | |
| assert_type(pl_1, nw.Series[pl.Series]) | |
| assert_type(pl_2, nw.Series[pl.Series]) | |
| assert_type(pl_3, nw.Series[pl.Series]) | |
| assert_type(pl_4, nw.Series[pl.Series]) | |
| assert_type(pl_5, nw.Series[pl.Series]) | |
| assert_type(pl_6, nw.Series[pl.Series]) | |
| assert_type(pl_7, nw.Series[pl.Series]) | |
| assert_type(pl_8, nw.Series[pl.Series]) | |
| assert_type(pl_9, nw.Series[pl.Series]) | |
| assert_type(pl_10, nw.Series[pl.Series]) | |
| assert_type(pl_11, nw.Series[pl.Series]) | |
| assert_type(pl_12, nw.Series[pl.Series]) | |
| pd_1 = nw.from_native(pd_ser, series_only=True) | |
| pd_2 = nw.from_native(pd_ser, allow_series=True) | |
| pd_3 = nw.from_native(pd_ser, eager_only=True, series_only=True) | |
| pd_4 = nw.from_native(pd_ser, eager_only=True, series_only=True, allow_series=True) | |
| pd_5 = nw.from_native(pd_ser, eager_only=True, allow_series=True) | |
| pd_6 = nw.from_native(pd_ser, series_only=True, allow_series=True) | |
| pd_7 = nw.from_native(pd_ser, series_only=True, pass_through=True) | |
| pd_8 = nw.from_native(pd_ser, allow_series=True, pass_through=True) | |
| pd_9 = nw.from_native(pd_ser, eager_only=True, series_only=True, pass_through=True) | |
| pd_10 = nw.from_native( | |
| pd_ser, eager_only=True, series_only=True, allow_series=True, pass_through=True | |
| ) | |
| pd_11 = nw.from_native(pd_ser, eager_only=True, allow_series=True, pass_through=True) | |
| pd_12 = nw.from_native(pd_ser, series_only=True, allow_series=True, pass_through=True) | |
| pds = pd_1, pd_2, pd_3, pd_4, pd_5, pd_6, pd_7, pd_8, pd_9, pd_10, pd_11, pd_12 | |
| assert_type(pd_1, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_2, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_3, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_4, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_5, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_6, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_7, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_8, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_9, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_10, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_11, nw.Series["pd.Series[Any]"]) | |
| assert_type(pd_12, nw.Series["pd.Series[Any]"]) | |
| pa_1 = nw.from_native(pa_ser, series_only=True) | |
| pa_2 = nw.from_native(pa_ser, allow_series=True) | |
| pa_3 = nw.from_native(pa_ser, eager_only=True, series_only=True) | |
| pa_4 = nw.from_native(pa_ser, eager_only=True, series_only=True, allow_series=True) | |
| pa_5 = nw.from_native(pa_ser, eager_only=True, allow_series=True) | |
| pa_6 = nw.from_native(pa_ser, series_only=True, allow_series=True) | |
| pa_7 = nw.from_native(pa_ser, series_only=True, pass_through=True) | |
| pa_8 = nw.from_native(pa_ser, allow_series=True, pass_through=True) | |
| pa_9 = nw.from_native(pa_ser, eager_only=True, series_only=True, pass_through=True) | |
| pa_10 = nw.from_native( | |
| pa_ser, eager_only=True, series_only=True, allow_series=True, pass_through=True | |
| ) | |
| pa_11 = nw.from_native(pa_ser, eager_only=True, allow_series=True, pass_through=True) | |
| pa_12 = nw.from_native(pa_ser, series_only=True, allow_series=True, pass_through=True) | |
| pas = pa_1, pa_2, pa_3, pa_4, pa_5, pa_6, pa_7, pa_8, pa_9, pa_10, pa_11, pa_12 | |
| assert_type(pa_1, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_2, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_3, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_4, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_5, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_6, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_7, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_8, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_9, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_10, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_11, nw.Series["pa.ChunkedArray[Any]"]) | |
| assert_type(pa_12, nw.Series["pa.ChunkedArray[Any]"]) | |
| for series in chain(pls, pds, pas): | |
| assert isinstance(series, nw.Series) |
narwhals/translate.py
Outdated
| class OnlyEager(TypedDict, total=False): | ||
| pass_through: bool | ||
| eager_only: Required[Literal[True]] | ||
| series_only: bool | ||
| allow_series: bool | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue 2
There's too many combinations we need to account for in @overloads.
- 3x different, true
native_objects - 3x narwhals objects pretending to be native
- 4x flags
- They all have a default, so needs
... allow_series: bool | Nonejust adds to it
- They all have a default, so needs
So instead of that:
from_native@overloads in nw.from_native
|
@MarcoGorelli this one isn't ready for a proper review yet, but I was curious about your thoughts on the idea generally? I did some testing locally and I think depending on how these issues have progressed - docs might be one more area for experimenting |
tbh I'm still not too keen on matching on the native object, i kind of feel like |
That isn't possible due to the issue from: Which is also why Possibly these sections of the spec may be helpful?
These are the issues that come from having such a heavily polymorphic function I'm afraid π’ |
`v1`, `v2` can extend and reuse these defs
Need `OnlySeries` to finish the other replacements
ooooh I think we're almost ready
I've updated the PR description with a table showing this in more detail
Closes #1928
What type of PR is this? (check all applicable)
Related issues
DataFrame,LazyFrameΒ #2944_native.pyΒ #3086Checklist
If you have comments or can explain your changes, please do so below
Note
The positive diff is from new quite exhaustive tests
narwhals.translatehas shrunk a lot π₯³TODO
Native{DataFrame,Series}nw._nativeNativeSeriesfiltermight not be needed anymoreTypedDicts intonw._translateTypedDicts forstablev1(thread)v1AllowAnyvariantsAllowSeriesvariantsOnlySeriesvariantsExcludeSeriesvariantsAllowLazyvariantsOnlyEagerOrInterchangevariantsv2mainfrom_native()@overloads 1-3 to useTypedDictsmkdocspresents the new typingTypedDicts?Unpackwork out-of-the-box?@overloads useful/noisy?Previews
fix(typing): More gracefully handle narwhals in overloads
Anywas previously being used on the first few overloads to solve a recursion issue.But that had the downside of swallowing completions until you matched an overload with
native_object:Show before
But now, we have something more useful
Show after
Changes to
Native{Series,DataFrame}As mentioned in (#3125 (comment)), these protocols needed to be
adapted to make them disjoint across
pandas,polars,pyarrow:dropfilteruniquevalue_countspd.Seriespl.Seriespa.ChunkedArraypd.DataFramepl.DataFramepa.Table