diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73a4b9ccf2..d0d004e1dd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,12 +62,12 @@ repos: name: don't import from narwhals.dtypes (use `Version.dtypes` instead) entry: | (?x) - import\ narwhals.dtypes| - from\ narwhals\ import\ dtypes| - from\ narwhals.dtypes\ import\ [^D_]+| - import\ narwhals.stable.v1.dtypes| - from\ narwhals.stable\.v.\ import\ dtypes| - from\ narwhals.stable\.v.\.dtypes\ import + import\ narwhals(\.stable\.v\d)?\.dtypes| + from\ narwhals(\.stable\.v\d)?\ import\ dtypes| + ^from\ narwhals(\.stable\.v\d)?\.dtypes\ import + \ (DType,\ )? + ((Datetime|Duration|Enum)(,\ )?)+ + ((,\ )?DType)? language: pygrep files: ^narwhals/ exclude: | diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index 8b02e209cb..553a5d9443 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -22,6 +22,7 @@ Here are the top-level functions available in Narwhals. - from_numpy - generate_temporary_column_name - get_native_namespace + - int_range - is_ordered_categorical - len - lit diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 52e0eb0506..aeff5646a7 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -59,6 +59,7 @@ from_dict, from_dicts, from_numpy, + int_range, len_ as len, lit, max, @@ -145,6 +146,7 @@ "from_numpy", "generate_temporary_column_name", "get_native_namespace", + "int_range", "is_ordered_categorical", "len", "lit", diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index c972b9f279..80727119e5 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -8,7 +8,7 @@ import pyarrow.compute as pc from narwhals._arrow.series import ArrowSeries -from narwhals._arrow.utils import native_to_narwhals_dtype +from narwhals._arrow.utils import int_range, native_to_narwhals_dtype from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind from narwhals._utils import ( @@ -507,16 +507,11 @@ def to_dict( def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: plx = self.__narwhals_namespace__() if order_by is None: - import numpy as np # ignore-banned-import - - data = pa.array(np.arange(len(self), dtype=np.int64)) - row_index = plx._expr._from_series( - plx._series.from_iterable(data, context=self, name=name) - ) + row_index = plx._expr._from_series(plx.int_range_eager(0, len(self))) else: rank = plx.col(order_by[0]).rank("ordinal", descending=False) - row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name) - return self.select(row_index, plx.all()) + row_index = rank.over(partition_by=[], order_by=order_by) - 1 + return self.select(row_index.alias(name), plx.all()) def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self: if isinstance(predicate, list): @@ -695,10 +690,8 @@ def write_csv(self, file: str | Path | BytesIO | None) -> str | None: return None def is_unique(self) -> ArrowSeries: - import numpy as np # ignore-banned-import - col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) - row_index = pa.array(np.arange(len(self))) + row_index = int_range(0, len(self)) keep_idx = ( self.native.append_column(col_token, row_index) .group_by(self.columns) @@ -722,8 +715,6 @@ def unique( ) -> Self: # The param `maintain_order` is only here for compatibility with the Polars API # and has no effect on the output. - import numpy as np # ignore-banned-import - if subset and (error := self._check_columns_exist(subset)): raise error subset = list(subset or self.columns) @@ -750,7 +741,7 @@ def unique( else: native = self.native keep_idx_native = ( - native.append_column(col_token, pa.array(np.arange(len(self)))) + native.append_column(col_token, int_range(0, len(self))) .group_by(subset) .aggregate([(col_token, agg_func)]) .column(f"{col_token}_{agg_func}") @@ -769,6 +760,8 @@ def gather_every(self, n: int, offset: int) -> Self: def to_arrow(self) -> pa.Table: return self.native + # TODO @dangotbanned: Replace `np.arange` w/ `utils.int_range` + # https://github.com/narwhals-dev/narwhals/issues/2722#issuecomment-3097350688 def sample( self, n: int | None, diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 3799aa87b2..e6282b92f2 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -12,13 +12,19 @@ from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.selectors import ArrowSelectorNamespace from narwhals._arrow.series import ArrowSeries -from narwhals._arrow.utils import cast_to_comparable_string_types +from narwhals._arrow.utils import ( + cast_to_comparable_string_types, + chunked_array, + int_range, + narwhals_to_native_dtype, +) from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen from narwhals._expression_parsing import ( combine_alias_output_names, combine_evaluate_output_names, ) from narwhals._utils import Implementation +from narwhals.dtypes import Int64 if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -26,7 +32,7 @@ from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete from narwhals._compliant.typing import ScalarKwargs from narwhals._utils import Version - from narwhals.typing import IntoDType, NonNestedLiteral + from narwhals.typing import IntegerDType, IntoDType, NonNestedLiteral class ArrowNamespace( @@ -278,6 +284,19 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: context=self, ) + def int_range_eager( + self, + start: int, + end: int, + step: int = 1, + *, + dtype: IntegerDType = Int64, + name: str = "literal", + ) -> ArrowSeries: + dtype_pa = narwhals_to_native_dtype(dtype, version=self._version) + data = int_range(start=start, end=end, step=step, dtype=dtype_pa) + return self._series.from_native(chunked_array([data]), name=name, context=self) + class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]): @property diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index fd73273605..a1de3cb69a 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -15,6 +15,7 @@ chunked_array, extract_native, floordiv_compat, + int_range, is_array_or_scalar, lit, narwhals_to_native_dtype, @@ -641,6 +642,8 @@ def zip_with(self, mask: Self, other: Self) -> Self: cond = mask.native.combine_chunks() return self._with_native(pc.if_else(cond, self.native, other.native)) + # TODO @dangotbanned: Replace `np.arange` w/ `utils.int_range` + # https://github.com/narwhals-dev/narwhals/issues/2722#issuecomment-3097350688 def sample( self, n: int | None, @@ -679,7 +682,7 @@ def fill_aux( # then it calculates the distance of each new index and the original index # if the distance is equal to or less than the limit and the original value is null, it is replaced valid_mask = pc.is_valid(arr) - indices = pa.array(np.arange(len(arr)), type=pa.int64()) + indices = int_range(0, len(arr)) if direction == "forward": valid_index = np.maximum.accumulate(np.where(valid_mask, indices, -1)) distance = indices - valid_index @@ -726,9 +729,7 @@ def is_unique(self) -> ArrowSeries: return self.to_frame().is_unique().alias(self.name) def is_first_distinct(self) -> Self: - import numpy as np # ignore-banned-import - - row_number = pa.array(np.arange(len(self))) + row_number = int_range(0, len(self)) col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) first_distinct_index = ( pa.Table.from_arrays([self.native], names=[self.name]) @@ -741,9 +742,7 @@ def is_first_distinct(self) -> Self: return self._with_native(pc.is_in(row_number, first_distinct_index)) def is_last_distinct(self) -> Self: - import numpy as np # ignore-banned-import - - row_number = pa.array(np.arange(len(self))) + row_number = int_range(0, len(self)) col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) last_distinct_index = ( pa.Table.from_arrays([self.native], names=[self.name]) @@ -799,6 +798,8 @@ def sort(self, *, descending: bool, nulls_last: bool) -> Self: ) return self._with_native(self.native.take(sorted_indices)) + # TODO @dangotbanned: Replace `np.arange` w/ `utils.int_range` + # https://github.com/narwhals-dev/narwhals/issues/2722#issuecomment-3097350688 def to_dummies(self, *, separator: str, drop_first: bool) -> ArrowDataFrame: import numpy as np # ignore-banned-import @@ -1165,6 +1166,8 @@ def _calculate_bins(self, bin_count: int) -> _1DArray: upper += 0.5 return self._linear_space(lower, upper, bin_count + 1) + # TODO @dangotbanned: Replace `np.arange` w/ `utils.int_range` + # https://github.com/narwhals-dev/narwhals/issues/2722#issuecomment-3097350688 def _calculate_hist(self, bins: list[float] | _1DArray) -> ArrowHistData: ser = self.native # NOTE: `mypy` refuses to resolve `ndarray.__getitem__` diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 8995a2517a..f9585d8a41 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -1,13 +1,13 @@ from __future__ import annotations from functools import lru_cache -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Final, cast import pyarrow as pa import pyarrow.compute as pc from narwhals._compliant import EagerSeriesNamespace -from narwhals._utils import Version, isinstance_or_issubclass +from narwhals._utils import Implementation, Version, isinstance_or_issubclass if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -21,6 +21,7 @@ ArrayOrScalarT1, ArrayOrScalarT2, ChunkedArrayAny, + Incomplete, NativeIntervalUnit, ScalarAny, ) @@ -57,6 +58,9 @@ def extract_regex( is_timestamp, ) +BACKEND_VERSION = Implementation.PYARROW._backend_version() +"""Static backend version for `pyarrow`.""" + UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = { "y": "year", "q": "quarter", @@ -73,6 +77,9 @@ def extract_regex( lit = pa.scalar """Alias for `pyarrow.scalar`.""" +int64: Final = pa.int64() +"""Initialized `pyarrow.types.Int64Type`.""" + def extract_py_scalar(value: Any, /) -> Any: from narwhals._arrow.series import maybe_extract_py_scalar @@ -441,4 +448,17 @@ def cast_to_comparable_string_types( return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype) +def int_range( + start: int, end: int, step: int = 1, *, dtype: pa.DataType = int64 +) -> ArrayAny: + if BACKEND_VERSION < (21, 0, 0): # pragma: no cover + import numpy as np # ignore-banned-import + + return pa.array(np.arange(start=start, stop=end, step=step), type=dtype) + # NOTE: Added in https://github.com/apache/arrow/pull/46778 + pa_arange = cast("Incomplete", pa.arange) # type: ignore[attr-defined] + arr: ArrayAny = pa_arange(start=start, stop=end, step=step) + return arr.cast(dtype) + + class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ... diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 4cc7130828..db9550f3bd 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -15,13 +15,15 @@ NativeFrameT, NativeSeriesT, ) -from narwhals._expression_parsing import is_expr, is_series +from narwhals._expression_parsing import combine_evaluate_output_names, is_expr, is_series from narwhals._utils import ( exclude_column_names, get_column_names, + not_implemented, passthrough_column_names, ) from narwhals.dependencies import is_numpy_array, is_numpy_array_2d +from narwhals.dtypes import Int64 if TYPE_CHECKING: from collections.abc import Container, Iterable, Sequence @@ -35,6 +37,7 @@ from narwhals.series import Series from narwhals.typing import ( ConcatMethod, + IntegerDType, Into1DArray, IntoDType, IntoSchema, @@ -109,6 +112,14 @@ def when( def concat_str( self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool ) -> CompliantExprT: ... + def int_range( + self, + start: CompliantExprT, + end: CompliantExprT, + step: int = 1, + *, + dtype: IntegerDType = Int64, + ) -> CompliantExprT: ... @property def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ... @@ -160,6 +171,8 @@ def from_native(self, data: NativeFrameT | Any, /) -> CompliantLazyFrameT: msg = f"Unsupported type: {type(data).__name__!r}" # pragma: no cover raise TypeError(msg) + int_range = not_implemented() # type: ignore[misc] + class EagerNamespace( DepthTrackingNamespace[EagerDataFrameT, EagerExprT], @@ -245,3 +258,39 @@ def concat( else: # pragma: no cover raise NotImplementedError return self._dataframe.from_native(native, context=self) + + def int_range_eager( + self, + start: int, + end: int, + step: int = 1, + *, + dtype: IntegerDType = Int64, + name: str = "literal", + ) -> EagerSeriesT: ... + + def int_range( + self, + start: EagerExprT, + end: EagerExprT, + step: int = 1, + *, + dtype: IntegerDType = Int64, + ) -> EagerExprT: + def func(df: EagerDataFrameT) -> list[EagerSeriesT]: + start_eval = start(df)[0] + name = start_eval.name + start_value = start_eval.item() + end_value = end(df)[0].item() + return [ + self.int_range_eager(start_value, end_value, step, dtype=dtype, name=name) + ] + + return self._expr._from_callable( + func=func, + depth=0, + function_name="int_range", + evaluate_output_names=combine_evaluate_output_names(start), + alias_output_names=None, + context=self, + ) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index ef73e76340..39827d620c 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -445,8 +445,7 @@ def estimated_size(self, unit: SizeUnit) -> int | float: def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: plx = self.__narwhals_namespace__() if order_by is None: - size = len(self) - data = self._array_funcs.arange(size) + data = self._array_funcs.arange(len(self)) row_index = plx._expr._from_series( plx._series.from_iterable( diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index d682add939..f3685cea77 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -16,8 +16,9 @@ from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT -from narwhals._pandas_like.utils import is_non_nullable_boolean +from narwhals._pandas_like.utils import import_array_module, is_non_nullable_boolean from narwhals._utils import zip_strict +from narwhals.dtypes import Int64 if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -26,7 +27,7 @@ from narwhals._compliant.typing import ScalarKwargs from narwhals._utils import Implementation, Version - from narwhals.typing import IntoDType, NonNestedLiteral + from narwhals.typing import IntegerDType, IntoDType, NonNestedLiteral Incomplete: TypeAlias = Any @@ -66,6 +67,14 @@ def _series(self) -> type[PandasLikeSeries]: def selectors(self) -> PandasSelectorNamespace: return PandasSelectorNamespace.from_namespace(self) + @property + def _array_funcs(self): # type: ignore[no-untyped-def] # noqa: ANN202 + if TYPE_CHECKING: + import numpy as np + + return np + return import_array_module(self._implementation) + def __init__(self, implementation: Implementation, version: Version) -> None: self._implementation = implementation self._version = version @@ -373,6 +382,18 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: context=self, ) + def int_range_eager( + self, + start: int, + end: int, + step: int = 1, + *, + dtype: IntegerDType = Int64, + name: str = "literal", + ) -> PandasLikeSeries: + data = self._array_funcs.arange(start, end, step) + return self._series.from_iterable(data, context=self, name=name, dtype=dtype) + class _NativeConcat(Protocol[NativeDataFrameT, NativeSeriesT]): @overload diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 8f257b36bd..20b70398e2 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -11,7 +11,7 @@ from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype from narwhals._utils import Implementation, requires, zip_strict from narwhals.dependencies import is_numpy_array_2d -from narwhals.dtypes import DType +from narwhals.dtypes import DType, Int64 if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -26,6 +26,7 @@ from narwhals.expr import Expr from narwhals.series import Series from narwhals.typing import ( + IntegerDType, Into1DArray, IntoDType, IntoSchema, @@ -230,6 +231,31 @@ def concat_str( version=self._version, ) + def int_range_eager( + self, + start: int, + end: int, + step: int = 1, + *, + dtype: IntegerDType = Int64, + name: str = "literal", + ) -> PolarsSeries: + dtype_pl = narwhals_to_native_dtype(dtype, self._version) + native = pl.int_range(start, end, step, dtype=dtype_pl, eager=True).alias(name) + return self._series.from_native(native, context=self) + + def int_range( + self, + start: PolarsExpr, + end: PolarsExpr, + step: int = 1, + *, + dtype: IntegerDType = Int64, + ) -> PolarsExpr: + pl_dtype = narwhals_to_native_dtype(dtype, self._version) + native = pl.int_range(start.native, end.native, step, dtype=pl_dtype) + return self._expr(native, self._version) + # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) # 1. Others have lots of private stuff for code reuse # i. None of that is useful here diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 397016b765..371776073f 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping + from polars.datatypes import IntegerType as PlIntegerType from typing_extensions import TypeIs from narwhals._compliant.typing import Accessor @@ -35,7 +36,7 @@ from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals.dtypes import DType - from narwhals.typing import IntoDType + from narwhals.typing import IntegerDType, IntoDType T = TypeVar("T") NativeT = TypeVar( @@ -203,9 +204,11 @@ def _version_dependent_dtypes() -> dict[type[DType], pl.DataType]: UNSUPPORTED_DTYPES = (dtypes.Decimal,) -def narwhals_to_native_dtype( # noqa: C901 - dtype: IntoDType, version: Version -) -> pl.DataType: +@overload +def narwhals_to_native_dtype(dtype: IntegerDType, version: Version) -> PlIntegerType: ... +@overload +def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pl.DataType: ... +def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pl.DataType: # noqa: C901 dtypes = version.dtypes base_type = dtype.base_type() if pl_type := NW_TO_PL_DTYPES.get(base_type): diff --git a/narwhals/functions.py b/narwhals/functions.py index 56252ec6c3..a4c438b15e 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -4,7 +4,7 @@ import sys from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal, overload from narwhals._expression_parsing import ( ExprKind, @@ -23,6 +23,7 @@ is_sequence_but_not_str, normalize_path, supports_arrow_c_stream, + unstable, validate_laziness, ) from narwhals.dependencies import ( @@ -31,6 +32,7 @@ is_numpy_array_2d, is_pyarrow_table, ) +from narwhals.dtypes import Int64 from narwhals.exceptions import InvalidOperationError from narwhals.expr import Expr from narwhals.series import Series @@ -49,6 +51,7 @@ ConcatMethod, FileSource, FrameT, + IntegerDType, IntoDType, IntoExpr, IntoSchema, @@ -1786,3 +1789,152 @@ def format(f_string: str, *args: IntoExpr) -> Expr: if len(s) > 0: exprs.append(lit(s)) return concat_str(exprs, separator="") + + +@overload +def int_range( + start: int | Expr, + end: int | Expr | None = ..., + step: int = ..., + *, + dtype: IntegerDType = ..., + eager: Literal[False] = ..., +) -> Expr: ... + + +@overload +def int_range( + start: int | Expr, + end: int | Expr | None = ..., + step: int = ..., + *, + dtype: IntegerDType = ..., + eager: IntoBackend[EagerAllowed], +) -> Series[Any]: ... + + +@unstable +def int_range( + start: int | Expr, + end: int | Expr | None = None, + step: int = 1, + *, + dtype: IntegerDType = Int64, + eager: IntoBackend[EagerAllowed] | Literal[False] = False, +) -> Expr | Series[Any]: + """Generate a range of integers. + + Warning: + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + start: Start of the range (inclusive). Defaults to 0. + end: End of the range (exclusive). If set to `None` (default), + the value of `start` is used and `start` is set to `0`. + step: Step size of the range. + dtype: Data type of the range (must be an integer data type). + eager: If set to `False` (default), then an expression is returned. + If set to an (eager) implementation ("pandas", "polars" or "pyarrow"), then + a `Series` is returned. + + Examples: + >>> import narwhals as nw + >>> nw.int_range(0, 5, step=2, eager="pandas") + ┌───────────────────────────┐ + | Narwhals Series | + |---------------------------| + |0 0 | + |1 2 | + |2 4 | + |Name: literal, dtype: int64| + └───────────────────────────┘ + + `end` can be omitted for a shorter syntax. + + >>> nw.int_range(5, step=2, eager="pandas") + ┌───────────────────────────┐ + | Narwhals Series | + |---------------------------| + |0 0 | + |1 2 | + |2 4 | + |Name: literal, dtype: int64| + └───────────────────────────┘ + + Generate an index column by using `int_range` in conjunction with :func:`len`. + + >>> import pandas as pd + >>> df = nw.from_native(pd.DataFrame({"a": [1, 3, 5], "b": [2, 4, 6]})) + >>> df.select(nw.int_range(nw.len(), dtype=nw.UInt32).alias("index"), nw.all()) + ┌──────────────────┐ + |Narwhals DataFrame| + |------------------| + | index a b | + | 0 0 1 2 | + | 1 1 3 4 | + | 2 2 5 6 | + └──────────────────┘ + """ + return _int_range_impl(start, end, step, dtype=dtype, eager=eager) + + +def _int_range_impl( + start: int | Expr, + end: int | Expr | None, + step: int, + *, + dtype: IntegerDType, + eager: IntoBackend[EagerAllowed] | Literal[False], +) -> Expr | Series[Any]: + from narwhals.exceptions import ComputeError + + if not dtype.is_integer(): + msg = f"non-integer `dtype` passed to `int_range`: {dtype}" + raise ComputeError(msg) + + if end is None: + end = start + start = 0 + + if not eager: + start = start if isinstance(start, Expr) else lit(start, dtype=dtype) + end = end if isinstance(end, Expr) else lit(end, dtype=dtype) + + if start._metadata.expansion_kind.is_multi_output(): + msg = "`start` must contain exactly one value, got expression returning multiple values" + raise ComputeError(msg) + + if end._metadata.expansion_kind.is_multi_output(): + msg = "`end` must contain exactly one value, got expression returning multiple values" + raise ComputeError(msg) + + args = start, end + return Expr( + lambda plx: apply_n_ary_operation( + plx, + partial(plx.int_range, step=step, dtype=dtype), + *args, + str_as_lit=False, + ), + ExprMetadata.selector_single(), + ) + + impl = Implementation.from_backend(eager) + if is_eager_allowed(impl): + if not (isinstance(start, int) and isinstance(end, int)): + msg = ( + f"Expected `start` and `end` to be integer values since `eager={eager}`.\n" + f"Found: `start` of type {type(start)} and `end` of type {type(end)}\n\n" + "Hint: Calling `nw.int_range` with expressions requires:\n" + " - `eager=False`" + " - a context such as `select` or `with_columns`" + ) + raise InvalidOperationError(msg) + + ns = Version.MAIN.namespace.from_backend(impl).compliant + series = ns.int_range_eager(start=start, end=end, step=step, dtype=dtype) + return series.to_narwhals() + + msg = f"Cannot create a Series from a lazy backend. Found: {impl}" + raise ValueError(msg) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index e97e3dc7d7..2e7961849e 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -24,7 +24,7 @@ from narwhals.dataframe import DataFrame as NwDataFrame, LazyFrame as NwLazyFrame from narwhals.exceptions import InvalidIntoExprError from narwhals.expr import Expr as NwExpr -from narwhals.functions import _new_series_impl, concat, show_versions +from narwhals.functions import _int_range_impl, _new_series_impl, concat, show_versions from narwhals.schema import Schema as NwSchema from narwhals.series import Series as NwSeries from narwhals.stable.v1 import dependencies, dtypes, selectors @@ -81,6 +81,7 @@ from narwhals.dtypes import DType from narwhals.typing import ( FileSource, + IntegerDType, IntoDType, IntoExpr, IntoFrame, @@ -1371,6 +1372,57 @@ def scan_parquet( return _stableify(nw_f.scan_parquet(source, backend=backend, **kwargs)) +@overload +def int_range( + start: int | Expr, + end: int | Expr | None = ..., + step: int = ..., + *, + dtype: IntegerDType = ..., + eager: Literal[False] = ..., +) -> Expr: ... + + +@overload +def int_range( + start: int | Expr, + end: int | Expr | None = ..., + step: int = ..., + *, + dtype: IntegerDType = ..., + eager: IntoBackend[EagerAllowed], +) -> Series[Any]: ... + + +def int_range( + start: int | Expr, + end: int | Expr | None = None, + step: int = 1, + *, + dtype: IntegerDType = Int64, + eager: IntoBackend[EagerAllowed] | Literal[False] = False, +) -> Expr | Series[Any]: + """Generate a range of integers. + + Warning: + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + start: Start of the range (inclusive). Defaults to 0. + end: End of the range (exclusive). If set to `None` (default), + the value of `start` is used and `start` is set to `0`. + step: Step size of the range. + dtype: Data type of the range (must be an integer data type). + eager: If set to `False` (default), then an expression is returned. + If set to an (eager) implementation ("pandas", "polars" or "pyarrow"), then + a `Series` is returned. + """ + return _stableify( + _int_range_impl(start=start, end=end, step=step, dtype=dtype, eager=eager) + ) + + __all__ = [ "Array", "Binary", @@ -1426,6 +1478,7 @@ def scan_parquet( "generate_temporary_column_name", "get_level", "get_native_namespace", + "int_range", "is_ordered_categorical", "len", "lit", diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py index d85b129b8a..e2459c8882 100644 --- a/narwhals/stable/v2/__init__.py +++ b/narwhals/stable/v2/__init__.py @@ -51,7 +51,7 @@ Unknown, ) from narwhals.expr import Expr as NwExpr -from narwhals.functions import _new_series_impl, concat, show_versions +from narwhals.functions import _int_range_impl, _new_series_impl, concat, show_versions from narwhals.schema import Schema as NwSchema from narwhals.series import Series as NwSeries from narwhals.stable.v2 import dependencies, dtypes, selectors @@ -76,6 +76,7 @@ from narwhals.dataframe import MultiColSelector, MultiIndexSelector from narwhals.dtypes import DType from narwhals.typing import ( + IntegerDType, IntoDType, IntoExpr, IntoFrame, @@ -1196,6 +1197,57 @@ def scan_parquet( return _stableify(nw_f.scan_parquet(source, backend=backend, **kwargs)) +@overload +def int_range( + start: int | Expr, + end: int | Expr | None = ..., + step: int = ..., + *, + dtype: IntegerDType = ..., + eager: Literal[False] = ..., +) -> Expr: ... + + +@overload +def int_range( + start: int | Expr, + end: int | Expr | None = ..., + step: int = ..., + *, + dtype: IntegerDType = ..., + eager: IntoBackend[EagerAllowed], +) -> Series[Any]: ... + + +def int_range( + start: int | Expr, + end: int | Expr | None = None, + step: int = 1, + *, + dtype: IntegerDType = Int64, + eager: IntoBackend[EagerAllowed] | Literal[False] = False, +) -> Expr | Series[Any]: + """Generate a range of integers. + + Warning: + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + Arguments: + start: Start of the range (inclusive). Defaults to 0. + end: End of the range (exclusive). If set to `None` (default), + the value of `start` is used and `start` is set to `0`. + step: Step size of the range. + dtype: Data type of the range (must be an integer data type). + eager: If set to `False` (default), then an expression is returned. + If set to an (eager) implementation ("pandas", "polars" or "pyarrow"), then + a `Series` is returned. + """ + return _stableify( + _int_range_impl(start=start, end=end, step=step, dtype=dtype, eager=eager) + ) + + __all__ = [ "Array", "Binary", @@ -1251,6 +1303,7 @@ def scan_parquet( "from_numpy", "generate_temporary_column_name", "get_native_namespace", + "int_range", "is_ordered_categorical", "len", "lit", diff --git a/narwhals/typing.py b/narwhals/typing.py index 356af6e66f..27d0b98856 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -369,7 +369,8 @@ def Binary(self) -> type[dtypes.Binary]: ... NonNestedDType: TypeAlias = "dtypes.NumericType | dtypes.TemporalType | dtypes.String | dtypes.Boolean | dtypes.Binary | dtypes.Categorical | dtypes.Unknown | dtypes.Object" """Any Narwhals DType that does not have required arguments.""" - +IntegerDType: TypeAlias = "dtypes.IntegerType | type[dtypes.IntegerType]" +"""Any signed or unsigned integer DType.""" IntoDType: TypeAlias = "dtypes.DType | type[NonNestedDType]" """Anything that can be converted into a Narwhals DType. diff --git a/tests/expr_and_series/int_range_test.py b/tests/expr_and_series/int_range_test.py new file mode 100644 index 0000000000..7b6b6b6637 --- /dev/null +++ b/tests/expr_and_series/int_range_test.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from narwhals import Implementation +from narwhals.exceptions import ComputeError, InvalidOperationError +from tests.utils import Constructor, assert_equal_data + +if TYPE_CHECKING: + from narwhals.dtypes import DType, IntegerType + from narwhals.typing import EagerAllowed + + +@pytest.mark.parametrize( + ("start", "end", "step", "dtype"), + [ + (0, 0, 1, nw.UInt8()), + (0, 3, 1, nw.UInt16), + (-3, 0, -1, nw.Int16()), + (0, 3, 2, nw.Int64), + (3, None, 1, nw.UInt32), + (3, None, 2, nw.Int8()), + ], +) +def test_int_range_eager( + start: int, + end: int | None, + step: int, + dtype: type[IntegerType] | IntegerType, + eager_implementation: EagerAllowed, +) -> None: + pytest.importorskip(str(eager_implementation)) + series = nw.int_range(start, end, step, dtype=dtype, eager=eager_implementation) + + assert series.dtype == dtype + if end is None: + end = start + start = 0 + assert_equal_data({"a": series}, {"a": list(range(start, end, step))}) + + +def test_int_range_eager_expr_raises(eager_implementation: EagerAllowed) -> None: + msg = "Expected `start` and `end` to be integer values" + with pytest.raises(InvalidOperationError, match=msg): + nw.int_range(nw.col("a").min(), nw.col("a").max() * 2, eager=eager_implementation) + + +@pytest.mark.parametrize( + ("start", "end", "step", "dtype", "expected"), + [ + (0, nw.len(), 1, nw.UInt8(), [0, 1, 2]), + (0, 3, 1, nw.UInt16, [0, 1, 2]), + (-3, nw.len() - 3, 1, nw.Int16(), [-3, -2, -1]), + (nw.len(), 0, -1, nw.Int64, [3, 2, 1]), + (nw.len(), None, 1, nw.UInt32, [0, 1, 2]), + ], +) +def test_int_range_lazy( + request: pytest.FixtureRequest, + constructor: Constructor, + start: int, + end: int | None, + step: int, + dtype: type[IntegerType] | IntegerType, + expected: list[int], +) -> None: + if any(x in str(constructor) for x in ("dask", "duckdb", "ibis", "spark")): + reason = "not implemented yet" + request.applymarker(pytest.mark.xfail(reason=reason)) + + data = {"a": ["foo", "bar", "baz"]} + frame = nw.from_native(constructor(data)) + result = frame.select(nw.int_range(start, end, step, dtype=dtype)) + + output_name = "len" if isinstance(start, nw.Expr) and end is not None else "literal" + assert_equal_data(result, {output_name: expected}) + assert result.collect_schema()[output_name] == dtype + + +@pytest.mark.parametrize( + "dtype", [nw.List, nw.Float64(), nw.Float32, nw.Decimal, nw.String()] +) +def test_int_range_non_int_dtype(dtype: DType) -> None: + msg = f"non-integer `dtype` passed to `int_range`: {dtype}" + with pytest.raises(ComputeError, match=msg): + nw.int_range(start=0, end=3, dtype=dtype) # type: ignore[call-overload] + + +@pytest.mark.parametrize( + ("start", "end"), + [ + (nw.col("foo", "bar").sum(), nw.col("foo", "bar").sum()), + (1, nw.col("foo", "bar").sum()), + ], +) +def test_int_range_multi_named(start: int | nw.Expr, end: int | nw.Expr | None) -> None: + prefix = "`start`" if isinstance(start, nw.Expr) else "`end`" + msg = f"{prefix} must contain exactly one value, got expression returning multiple values" + with pytest.raises(ComputeError, match=msg): + nw.int_range(start=start, end=end) + + +def test_int_range_eager_set_to_lazy_backend() -> None: + with pytest.raises(ValueError, match="Cannot create a Series from a lazy backend"): + nw.int_range(123, eager=Implementation.DUCKDB) # type: ignore[call-overload] diff --git a/tests/v1_test.py b/tests/v1_test.py index 5ef46423ea..42b38a53b1 100644 --- a/tests/v1_test.py +++ b/tests/v1_test.py @@ -1112,3 +1112,15 @@ def test_mode_different_lengths(constructor_eager: ConstructorEager) -> None: df = nw_v1.from_native(constructor_eager({"a": [1, 1, 2], "b": [4, 5, 6]})) with pytest.raises(ShapeError): df.select(nw_v1.col("a", "b").mode()) + + +def test_int_range() -> None: + pytest.importorskip("pandas") + + def minimal_function(data: nw_v1.Series[Any]) -> None: + data.is_null() + + col = nw_v1.int_range(0, 3, eager="pandas") + # check this doesn't raise type-checking errors + minimal_function(col) + assert isinstance(col, nw_v1.Series) diff --git a/tests/v2_test.py b/tests/v2_test.py index 8a727bbf55..31509a9065 100644 --- a/tests/v2_test.py +++ b/tests/v2_test.py @@ -501,3 +501,15 @@ def test_mode_different_lengths(constructor_eager: ConstructorEager) -> None: df = nw_v2.from_native(constructor_eager({"a": [1, 1, 2], "b": [4, 5, 6]})) with pytest.raises(ShapeError): df.select(nw_v2.col("a", "b").mode()) + + +def test_int_range() -> None: + pytest.importorskip("pandas") + + def minimal_function(data: nw_v2.Series[Any]) -> None: + data.is_null() + + col = nw_v2.int_range(0, 3, eager="pandas") + # check this doesn't raise type-checking errors + minimal_function(col) + assert isinstance(col, nw_v2.Series)