Skip to content

Commit 1de571b

Browse files
authored
perf: Optimize, generalize _DelayedCategories -> _DeferredIterable (#2421)
1 parent fa81711 commit 1de571b

File tree

6 files changed

+90
-38
lines changed

6 files changed

+90
-38
lines changed

narwhals/_pandas_like/utils.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77
from typing import TYPE_CHECKING
88
from typing import Any
99
from typing import Callable
10+
from typing import Literal
1011
from typing import Sequence
1112
from typing import Sized
1213
from typing import TypeVar
1314

1415
import pandas as pd
1516

1617
from narwhals._compliant.series import EagerSeriesNamespace
17-
from narwhals.dtypes import _DelayedCategories
1818
from narwhals.exceptions import ColumnNotFoundError
1919
from narwhals.exceptions import DuplicateError
2020
from narwhals.exceptions import ShapeError
2121
from narwhals.utils import Implementation
2222
from narwhals.utils import Version
23+
from narwhals.utils import _DeferredIterable
2324
from narwhals.utils import isinstance_or_issubclass
2425

2526
T = TypeVar("T", bound=Sized)
@@ -255,9 +256,7 @@ def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) ->
255256
if dtype.startswith("dictionary<"):
256257
return dtypes.Categorical()
257258
if dtype == "category":
258-
return native_categorical_to_narwhals_dtype(
259-
native_dtype, version, lambda: tuple(native_dtype.categories)
260-
)
259+
return native_categorical_to_narwhals_dtype(native_dtype, version)
261260
if (match_ := PATTERN_PD_DATETIME.match(dtype)) or (
262261
match_ := PATTERN_PA_DATETIME.match(dtype)
263262
):
@@ -305,20 +304,33 @@ def object_native_to_narwhals_dtype(
305304
def native_categorical_to_narwhals_dtype(
306305
native_dtype: pd.CategoricalDtype,
307306
version: Version,
308-
get_categories: Callable[[], tuple[str, ...]],
307+
implementation: Literal[Implementation.CUDF] | None = None,
309308
) -> DType:
310309
dtypes = version.dtypes
311310
if version is Version.V1:
312311
return dtypes.Categorical()
313312
if native_dtype.ordered:
314-
return dtypes.Enum(_DelayedCategories(get_categories))
313+
into_iter = (
314+
_cudf_categorical_to_list(native_dtype)
315+
if implementation is Implementation.CUDF
316+
else native_dtype.categories.to_list
317+
)
318+
return dtypes.Enum(_DeferredIterable(into_iter))
315319
return dtypes.Categorical()
316320

317321

318-
def native_to_narwhals_dtype(
322+
def _cudf_categorical_to_list(
319323
native_dtype: Any,
320-
version: Version,
321-
implementation: Implementation,
324+
) -> Callable[[], list[Any]]: # pragma: no cover
325+
# NOTE: https://docs.rapids.ai/api/cudf/stable/user_guide/api_docs/api/cudf.core.dtypes.categoricaldtype/#cudf.core.dtypes.CategoricalDtype
326+
def fn() -> list[Any]:
327+
return native_dtype.categories.to_arrow().to_pylist()
328+
329+
return fn
330+
331+
332+
def native_to_narwhals_dtype(
333+
native_dtype: Any, version: Version, implementation: Implementation
322334
) -> DType:
323335
str_dtype = str(native_dtype)
324336

@@ -333,8 +345,9 @@ def native_to_narwhals_dtype(
333345
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
334346
if str_dtype == "category" and implementation.is_cudf():
335347
# https://github.com/rapidsai/cudf/issues/18536
348+
# https://github.com/rapidsai/cudf/issues/14027
336349
return native_categorical_to_narwhals_dtype(
337-
native_dtype, version, lambda: tuple(native_dtype.categories.to_pandas())
350+
native_dtype, version, Implementation.CUDF
338351
)
339352
if str_dtype != "object":
340353
return non_object_native_to_narwhals_dtype(native_dtype, version)

narwhals/_polars/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
from typing import Iterator
88
from typing import Mapping
99
from typing import TypeVar
10+
from typing import cast
1011
from typing import overload
1112

1213
import polars as pl
1314

14-
from narwhals.dtypes import _DelayedCategories
1515
from narwhals.exceptions import ColumnNotFoundError
1616
from narwhals.exceptions import ComputeError
1717
from narwhals.exceptions import DuplicateError
1818
from narwhals.exceptions import InvalidOperationError
1919
from narwhals.exceptions import NarwhalsError
2020
from narwhals.exceptions import ShapeError
2121
from narwhals.utils import Version
22+
from narwhals.utils import _DeferredIterable
2223
from narwhals.utils import isinstance_or_issubclass
2324

2425
if TYPE_CHECKING:
@@ -101,7 +102,12 @@ def native_to_narwhals_dtype(
101102
if isinstance_or_issubclass(dtype, pl.Enum):
102103
if version is Version.V1:
103104
return dtypes.Enum() # type: ignore[call-arg]
104-
return dtypes.Enum(_DelayedCategories(lambda: tuple(dtype.categories)))
105+
categories = _DeferredIterable(
106+
dtype.categories.to_list
107+
if backend_version >= (0, 20, 4)
108+
else lambda: cast("list[str]", dtype.categories)
109+
)
110+
return dtypes.Enum(categories)
105111
if dtype == pl.Date:
106112
return dtypes.Date()
107113
if isinstance_or_issubclass(dtype, pl.Datetime):

narwhals/dtypes.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from typing import Iterable
99
from typing import Mapping
1010

11+
from narwhals.utils import _DeferredIterable
1112
from narwhals.utils import isinstance_or_issubclass
1213

1314
if TYPE_CHECKING:
14-
from typing import Callable
1515
from typing import Iterator
1616
from typing import Sequence
1717

@@ -461,22 +461,6 @@ class Categorical(DType):
461461
"""
462462

463463

464-
class _DelayedCategories:
465-
"""Store callable which produces tupleified version of Enum's categories.
466-
467-
!!! warning
468-
This class is not meant to be instantiated directly and exists
469-
for internal usage only.
470-
"""
471-
472-
def __init__(self, get_categories: Callable[[], tuple[str, ...]]) -> None:
473-
self.get_categories = get_categories
474-
475-
def __iter__(self) -> Iterator[str]: # pragma: no cover
476-
msg = "This is only provided for type-checking and should not be called"
477-
raise AssertionError(msg)
478-
479-
480464
class Enum(DType):
481465
"""A fixed categorical encoding of a unique set of strings.
482466
@@ -490,10 +474,10 @@ class Enum(DType):
490474
"""
491475

492476
def __init__(self, categories: Iterable[str] | type[enum.Enum]) -> None:
493-
self._delayed_categories: _DelayedCategories | None = None
477+
self._delayed_categories: _DeferredIterable[str] | None = None
494478
self._cached_categories: tuple[str, ...] | None = None
495479

496-
if isinstance(categories, _DelayedCategories):
480+
if isinstance(categories, _DeferredIterable):
497481
self._delayed_categories = categories
498482
elif isinstance(categories, type) and issubclass(categories, enum.Enum):
499483
self._cached_categories = tuple(member.value for member in categories)
@@ -502,10 +486,14 @@ def __init__(self, categories: Iterable[str] | type[enum.Enum]) -> None:
502486

503487
@property
504488
def categories(self) -> tuple[str, ...]:
505-
if self._cached_categories is None:
506-
assert self._delayed_categories is not None # noqa: S101
507-
self._cached_categories = self._delayed_categories.get_categories()
508-
return self._cached_categories
489+
if cached := self._cached_categories:
490+
return cached
491+
elif delayed := self._delayed_categories:
492+
self._cached_categories = delayed.to_tuple()
493+
return self._cached_categories
494+
else: # pragma: no cover
495+
msg = f"Internal structure of {type(self).__name__!r} is invalid."
496+
raise TypeError(msg)
509497

510498
def __eq__(self, other: object) -> bool:
511499
# allow comparing object instances to class

narwhals/utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from typing import Any
1515
from typing import Callable
1616
from typing import Container
17+
from typing import Generic
1718
from typing import Iterable
19+
from typing import Iterator
1820
from typing import Literal
1921
from typing import Protocol
2022
from typing import Sequence
@@ -102,7 +104,7 @@
102104
FrameOrSeriesT = TypeVar(
103105
"FrameOrSeriesT", bound=Union[LazyFrame[Any], DataFrame[Any], Series[Any]]
104106
)
105-
_T = TypeVar("_T")
107+
106108
_T1 = TypeVar("_T1")
107109
_T2 = TypeVar("_T2")
108110
_T3 = TypeVar("_T3")
@@ -150,6 +152,7 @@ class _StoresColumns(Protocol):
150152
def columns(self) -> Sequence[str]: ...
151153

152154

155+
_T = TypeVar("_T")
153156
NativeT_co = TypeVar("NativeT_co", covariant=True)
154157
CompliantT_co = TypeVar("CompliantT_co", covariant=True)
155158
_ContextT = TypeVar("_ContextT", bound="_FullContext")
@@ -1946,3 +1949,18 @@ def decorate(init_child: _Constructor[_T, P, R2], /) -> _Constructor[_T, P, R2]:
19461949
raise TypeError(msg)
19471950

19481951
return decorate
1952+
1953+
1954+
class _DeferredIterable(Generic[_T]):
1955+
"""Store a callable producing an iterable to defer collection until we need it."""
1956+
1957+
def __init__(self, into_iter: Callable[[], Iterable[_T]], /) -> None:
1958+
self._into_iter: Callable[[], Iterable[_T]] = into_iter
1959+
1960+
def __iter__(self) -> Iterator[_T]:
1961+
yield from self._into_iter()
1962+
1963+
def to_tuple(self) -> tuple[_T, ...]:
1964+
# Collect and return as a `tuple`.
1965+
it = self._into_iter()
1966+
return it if isinstance(it, tuple) else tuple(it)

tests/dtypes_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ def test_enum_categories_immutable() -> None:
435435
dtype = nw.Enum(["a", "b"])
436436
with pytest.raises(TypeError, match="does not support item assignment"):
437437
dtype.categories[0] = "c" # type: ignore[index]
438+
with pytest.raises(AttributeError):
439+
dtype.categories = "a", "b", "c" # type: ignore[misc]
438440

439441

440442
def test_enum_repr_pd() -> None:
@@ -444,7 +446,8 @@ def test_enum_repr_pd() -> None:
444446
)
445447
)
446448
dtype = df.schema["a"]
447-
assert dtype.categories == ("broccoli", "cabbage") # type: ignore[attr-defined]
449+
assert isinstance(dtype, nw.Enum)
450+
assert dtype.categories == ("broccoli", "cabbage")
448451
assert "Enum(categories=['broccoli', 'cabbage'])" in str(dtype)
449452

450453

@@ -458,7 +461,8 @@ def test_enum_repr_pl() -> None:
458461
)
459462
)
460463
dtype = df.schema["a"]
461-
assert dtype.categories == ("broccoli", "cabbage") # type: ignore[attr-defined]
464+
assert isinstance(dtype, nw.Enum)
465+
assert dtype.categories == ("broccoli", "cabbage")
462466
assert "Enum(categories=['broccoli', 'cabbage'])" in repr(dtype)
463467

464468

tests/utils_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
import re
44
import string
55
from dataclasses import dataclass
6+
from itertools import chain
67
from typing import TYPE_CHECKING
78
from typing import Any
9+
from typing import Callable
10+
from typing import Iterable
11+
from typing import Iterator
812
from typing import Protocol
913
from typing import cast
1014

@@ -22,6 +26,7 @@
2226
from narwhals.exceptions import ColumnNotFoundError
2327
from narwhals.utils import Implementation
2428
from narwhals.utils import Version
29+
from narwhals.utils import _DeferredIterable
2530
from narwhals.utils import check_column_exists
2631
from narwhals.utils import deprecate_native_namespace
2732
from narwhals.utils import parse_version
@@ -557,3 +562,21 @@ def repeat(self, n: int) -> str:
557562
)
558563
with pytest.raises(NotImplementedError, match=pattern):
559564
v_05.concat("never")
565+
566+
567+
def test_deferred_iterable() -> None:
568+
def to_upper(it: Iterable[str]) -> Callable[[], Iterator[str]]:
569+
def fn() -> Iterator[str]:
570+
for el in it:
571+
yield el.capitalize()
572+
573+
return fn
574+
575+
iterable = list("hello")
576+
deferred_1 = _DeferredIterable(iterable.copy)
577+
deferred_2 = _DeferredIterable(to_upper(iterable))
578+
579+
assert deferred_1.to_tuple() == tuple("hello")
580+
assert next(iter(deferred_1)) == "h"
581+
assert list(deferred_1) == list("hello")
582+
assert "".join(chain(deferred_1, deferred_2)) == "helloHELLO"

0 commit comments

Comments
 (0)