Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion narwhals/_plan/_expr_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

from typing_extensions import Self, TypeAlias

from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
from narwhals._plan.expr import Expr, Selector
from narwhals._plan.expressions.expr import Alias, Cast, Column
from narwhals._plan.meta import MetaNamespace
from narwhals._plan.protocols import Ctx, FrameT_contra, R_co
from narwhals._plan.typing import ExprIRT2, MapIR, Seq
from narwhals.dtypes import DType

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_plan/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing_extensions import TypeIs

from narwhals._plan import expressions as ir
from narwhals._plan.compliant.series import CompliantSeries
from narwhals._plan.expr import Expr
from narwhals._plan.protocols import CompliantSeries
from narwhals._plan.series import Series
from narwhals._plan.typing import NativeSeriesT, Seq
from narwhals.typing import NonNestedLiteral
Expand Down
11 changes: 3 additions & 8 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.compliant.dataframe import EagerDataFrame
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.expressions import NamedIR
from narwhals._plan.protocols import EagerDataFrame, namespace
from narwhals._plan.typing import Seq
from narwhals._utils import Version, parse_columns_to_drop
from narwhals.schema import Schema
Expand All @@ -23,10 +24,9 @@

from typing_extensions import Self

from narwhals._arrow.typing import ChunkedArrayAny
from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
from narwhals._plan.arrow.namespace import ArrowNamespace
from narwhals._plan.dataframe import DataFrame as NwDataFrame
from narwhals._plan.expressions import ExprIR, NamedIR
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.typing import Seq
Expand Down Expand Up @@ -62,11 +62,6 @@ def schema(self) -> dict[str, DType]:
def __len__(self) -> int:
    """Return the `num_rows` value reported by the wrapped native object."""
    native = self.native
    return native.num_rows

def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]:
    """Wrap this compliant frame in the narwhals-level `DataFrame`.

    The wrapper class is parametrized with `pa.Table` and `"ChunkedArrayAny"`
    and built through its `_from_compliant` constructor.
    """
    # Imported inside the method rather than at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from narwhals._plan.dataframe import DataFrame

    wrapper = DataFrame[pa.Table, "ChunkedArrayAny"]
    return wrapper._from_compliant(self)

@classmethod
def from_dict(
cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None
Expand Down
23 changes: 4 additions & 19 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co
from narwhals._plan.compliant.column import ExprDispatch
from narwhals._plan.compliant.expr import EagerExpr
from narwhals._plan.compliant.scalar import EagerScalar
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.expressions import NamedIR
from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace
from narwhals._utils import (
Implementation,
Version,
Expand Down Expand Up @@ -449,28 +452,10 @@ def broadcast(self, length: int) -> Series:
chunked = fn.chunked_array(pa_repeat(scalar, length))
return Series.from_native(chunked, self.name, version=self.version)

def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar:
    """Return the constant index `0` as a scalar named `name`.

    `node` and `frame` are accepted for the dispatch signature but not read.
    """
    index = pa.scalar(0)
    return self._with_native(index, name)

def arg_max(self, node: ArgMax, frame: Frame, name: str) -> Scalar:
    """Return the constant index `0` as a scalar named `name`.

    `node` and `frame` are accepted for the dispatch signature but not read.
    """
    index = pa.scalar(0)
    return self._with_native(index, name)

def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar:
    """Return the constant `1` as a scalar named `name`.

    `node` and `frame` are accepted for the dispatch signature but not read.
    """
    unique_count = pa.scalar(1)
    return self._with_native(unique_count, name)

def std(self, node: Std, frame: Frame, name: str) -> Scalar:
    """Return a null scalar (pyarrow `null` type) named `name`.

    `node` and `frame` are accepted for the dispatch signature but not read.
    """
    null_scalar = pa.scalar(None, pa.null())
    return self._with_native(null_scalar, name)

def var(self, node: Var, frame: Frame, name: str) -> Scalar:
    """Return a null scalar (pyarrow `null` type) named `name`.

    `node` and `frame` are accepted for the dispatch signature but not read.
    """
    null_scalar = pa.scalar(None, pa.null())
    return self._with_native(null_scalar, name)

def count(self, node: Count, frame: Frame, name: str) -> Scalar:
    """Return `1` when the evaluated input scalar is valid, otherwise `0`.

    The input is obtained by dispatching `node.expr` against `self`, `frame`
    and `name`; the result is wrapped as a scalar named `name`.
    """
    evaluated = node.expr.dispatch(self, frame, name)
    native = evaluated.native
    if native.is_valid:
        valid_count = 1
    else:
        valid_count = 0
    return self._with_native(pa.scalar(valid_count), name)

def len(self, node: Len, frame: Frame, name: str) -> Scalar:
    """Return the constant `1` as a scalar named `name`.

    `node` and `frame` are accepted for the dispatch signature but not read.
    """
    length = pa.scalar(1)
    return self._with_native(length, name)

filter = not_implemented()
over = not_implemented()
over_ordered = not_implemented()
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_plan/arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from narwhals._plan._guards import is_agg_expr, is_function_expr
from narwhals._plan.arrow import acero, functions as fn, options
from narwhals._plan.common import dispatch_method_name, temp
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
from narwhals._plan.expressions import aggregation as agg
from narwhals._plan.protocols import EagerDataFrameGroupBy
from narwhals._utils import Implementation
from narwhals.exceptions import InvalidOperationError

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_plan/arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._plan._guards import is_tuple_of
from narwhals._plan.arrow import functions as fn
from narwhals._plan.compliant.namespace import EagerNamespace
from narwhals._plan.expressions.literal import is_literal_scalar
from narwhals._plan.protocols import EagerNamespace
from narwhals._utils import Version
from narwhals.exceptions import InvalidOperationError

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_plan/arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype
from narwhals._plan.arrow import functions as fn
from narwhals._plan.protocols import CompliantSeries
from narwhals._plan.compliant.series import CompliantSeries
from narwhals._utils import Version
from narwhals.dependencies import is_numpy_array_1d

Expand Down
13 changes: 12 additions & 1 deletion narwhals/_plan/arrow/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload

from narwhals._typing_compat import TypeVar
Expand All @@ -23,10 +23,21 @@
)
from typing_extensions import TypeAlias

from narwhals.typing import NativeDataFrame, NativeSeries

StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]"
IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type"
IntegerScalar: TypeAlias = "Scalar[IntegerType]"

class NativeArrowSeries(NativeSeries, Protocol):
    """Structural type: a `NativeSeries` that also exposes `chunks`."""

    @property
    def chunks(self) -> list[Any]:
        """The list of chunks exposed by the native series."""
        ...

class NativeArrowDataFrame(NativeDataFrame, Protocol):
    """Structural type: a `NativeDataFrame` with arrow-style column access."""

    def column(self, *args: Any, **kwds: Any) -> NativeArrowSeries:
        """Return a single column as a `NativeArrowSeries`."""
        ...

    @property
    def columns(self) -> Sequence[NativeArrowSeries]:
        """All columns, each as a `NativeArrowSeries`."""
        ...


ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]")
ScalarPT_contra = TypeVar(
Expand Down
1 change: 1 addition & 0 deletions narwhals/_plan/compliant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import annotations
99 changes: 99 additions & 0 deletions narwhals/_plan/compliant/column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from collections.abc import Sized
from typing import TYPE_CHECKING, Protocol

from narwhals._plan.common import flatten_hash_safe
from narwhals._plan.compliant.typing import (
FrameT_contra,
HasVersion,
LengthT,
NamespaceT_co,
R_co,
SeriesT,
)

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence

from typing_extensions import Self

from narwhals._plan import expressions as ir
from narwhals._plan.typing import OneOrIterable


class SupportsBroadcast(Protocol[SeriesT, LengthT]):
    """Minimal broadcasting for `Expr` results.

    Implementations supply a notion of length (`_length`, `_length_max`,
    `_length_required`) plus conversions to a series (`to_series`,
    `broadcast`). `align` combines them to yield one series per expression.
    """

    def _length(self) -> LengthT:
        """Return the length of the current expression."""
        ...

    @classmethod
    def _length_all(
        cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], /
    ) -> Sequence[LengthT]:
        """Return the `_length()` of every expression, in input order."""
        return [expr._length() for expr in exprs]

    @classmethod
    def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT:
        """Return the maximum of `lengths`."""
        ...

    @classmethod
    def _length_required(
        cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], /
    ) -> LengthT | None:
        """Return the length to broadcast to, or `None` if no broadcast is needed."""

    @classmethod
    def align(
        cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]]
    ) -> Iterator[SeriesT]:
        """Yield one series per expression, broadcasting only when required.

        The (possibly nested) arguments are flattened first. If
        `_length_required` reports a target length, every expression is
        broadcast to it; otherwise each is converted with `to_series()`.
        """
        flat: tuple[SupportsBroadcast[SeriesT, LengthT], ...] = tuple(
            flatten_hash_safe(exprs)
        )
        target = cls._length_required(flat)
        if target is not None:
            for expr in flat:
                yield expr.broadcast(target)
            return
        for expr in flat:
            yield expr.to_series()

    def broadcast(self, length: LengthT, /) -> SeriesT:
        """Return this expression as a series for the requested `length`."""
        ...

    @classmethod
    def from_series(cls, series: SeriesT, /) -> Self:
        """Construct an instance from `series`."""
        ...

    def to_series(self) -> SeriesT:
        """Return this expression as a series, without broadcasting."""
        ...


class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]):
    """Determines expression length via the size of the container.

    Lengths are plain `int`s obtained from `len()`, so the broadcast target
    is simply the largest length among the expressions being aligned.
    """

    def _length(self) -> int:
        """Return `len(self)`."""
        return len(self)

    @classmethod
    def _length_max(cls, lengths: Sequence[int], /) -> int:
        """Return the largest value in `lengths`.

        Raises:
            ValueError: If `lengths` is empty (propagated from `max`).
        """
        return max(lengths)

    @classmethod
    def _length_required(
        cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], /
    ) -> int | None:
        """Return the maximum length if any expression differs from it.

        Returns `None` when every expression already has the same length,
        and also when `exprs` is empty, since there is nothing to broadcast.
        """
        # Guard the empty case explicitly: `max()` of an empty sequence raises
        # `ValueError`, which would surface as an opaque error from `align()`.
        if not exprs:
            return None
        lengths = cls._length_all(exprs)
        max_length = cls._length_max(lengths)
        required = any(len_ != max_length for len_ in lengths)
        return max_length if required else None


class ExprDispatch(HasVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]):
    """Entry points for evaluating expression IR nodes against a frame."""

    # NOTE: Needs to stay `covariant` and never be used as a parameter
    def __narwhals_namespace__(self) -> NamespaceT_co: ...

    @classmethod
    def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co:
        """Evaluate `node` against `frame`, returning the dispatch result.

        A bare instance is created with `__new__` (so `__init__` is not run),
        its `_version` is copied from `frame.version`, and the instance is
        handed to `node.dispatch` together with `frame` and `name`.
        """
        dispatcher = cls.__new__(cls)
        dispatcher._version = frame.version
        return node.dispatch(dispatcher, frame, name)

    @classmethod
    def from_named_ir(cls, named_ir: ir.NamedIR[ir.ExprIR], frame: FrameT_contra) -> R_co:
        """Evaluate a named IR by forwarding its `expr` and `name` to `from_ir`."""
        expr, name = named_ir.expr, named_ir.name
        return cls.from_ir(expr, frame, name)
Loading
Loading