Skip to content

RFC: Introduce pandas.col #62103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
Series,
DataFrame,
)
from pandas.core.col import col

from pandas.core.dtypes.dtypes import SparseDtype

Expand Down Expand Up @@ -281,6 +282,7 @@
"array",
"arrays",
"bdate_range",
"col",
"concat",
"crosstab",
"cut",
Expand Down
246 changes: 246 additions & 0 deletions pandas/core/col.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
from __future__ import annotations

from collections.abc import (
Callable,
Hashable,
)
from typing import (
TYPE_CHECKING,
Any,
)

from pandas.core.dtypes.common import is_scalar

from pandas.core.series import Series

if TYPE_CHECKING:
from pandas import DataFrame


# Maps operator dunder names to the symbol used when building the
# human-readable repr of a binary expression. Reflected variants
# (``__r*__``) share the symbol of their forward counterpart. Names that
# are absent from this table are looked up with ``.get(op, op)`` and so
# fall back to the dunder name itself.
OP_SYMBOLS = {
    # Arithmetic (forward / reflected)
    "__add__": "+",
    "__radd__": "+",
    "__sub__": "-",
    "__rsub__": "-",
    "__mul__": "*",
    "__rmul__": "*",
    "__truediv__": "/",
    "__rtruediv__": "/",
    "__floordiv__": "//",
    "__rfloordiv__": "//",
    # Comparisons
    "__ge__": ">=",
    "__gt__": ">",
    "__le__": "<=",
    "__lt__": "<",
    "__eq__": "==",
    "__ne__": "!=",
    # Modulo (forward only)
    "__mod__": "%",
}


def parse_args(df: DataFrame, *args: Any) -> tuple[Series]:
    """
    Resolve positional arguments against ``df``.

    Every argument that is an ``Expr`` is replaced by the result of calling
    its underlying function on ``df``; all other arguments are passed
    through unchanged. The original order is preserved.
    """
    resolved = []
    for arg in args:
        if isinstance(arg, Expr):
            # Evaluate the deferred expression on this particular frame.
            arg = arg._func(df)
        resolved.append(arg)
    return tuple(resolved)


def parse_kwargs(df: DataFrame, **kwargs: Any) -> dict[Hashable, Series]:
    """
    Resolve keyword arguments against ``df``.

    Every value that is an ``Expr`` is replaced by the result of calling its
    underlying function on ``df``; all other values are passed through
    unchanged. Keys are kept as given.
    """
    resolved = {}
    for name, value in kwargs.items():
        if isinstance(value, Expr):
            # Evaluate the deferred expression on this particular frame.
            value = value._func(df)
        resolved[name] = value
    return resolved


class Expr:
    """
    Deferred computation over a DataFrame.

    An ``Expr`` wraps a callable that takes a DataFrame and produces a
    value, together with an optional string used by ``repr``. Operators and
    attribute access do not compute anything; they build new ``Expr``
    objects. Evaluation happens when the expression is called with a
    DataFrame.

    Parameters
    ----------
    func : callable
        Function taking a DataFrame and returning the expression's value.
    repr_str : str, optional
        Text returned by ``repr``. When it is ``None`` or empty, ``repr``
        returns ``"Expr(...)"``.

    Notes
    -----
    ``==`` and ``!=`` return a new ``Expr`` rather than a bool. Because
    ``__eq__`` is overridden without ``__hash__``, instances are unhashable.
    """

    def __init__(
        self, func: Callable[[DataFrame], Any], repr_str: str | None = None
    ) -> None:
        self._func = func
        self._repr_str = repr_str

    def __call__(self, df: DataFrame) -> Series:
        """
        Evaluate the expression on ``df``.

        Raises
        ------
        TypeError
            If the wrapped function returns something that is neither a
            Series nor a scalar.
        """
        result = self._func(df)
        if not (isinstance(result, Series) or is_scalar(result)):
            msg = (
                "Expected function which returns Series or scalar, "
                f"got function which returns: {type(result)}"
            )
            raise TypeError(msg)
        return result

    @staticmethod
    def _format_call_args(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
        """
        Render call arguments for use inside a repr string.

        ``Expr`` arguments are shown through their own repr string, without
        surrounding quotes, whether passed positionally or by keyword;
        everything else is shown with ``repr``.
        """

        def render(value: Any) -> str:
            if isinstance(value, Expr):
                return f"{value._repr_str}"
            return repr(value)

        parts = [render(arg) for arg in args]
        parts.extend(f"{key}={render(val)}" for key, val in kwargs.items())
        return ", ".join(parts)

    def _with_binary_op(self, op: str, other: Any) -> Expr:
        """
        Build the expression for ``self <op> other``.

        ``op`` is the dunder method name that will be looked up on the
        evaluated value of ``self``. If ``other`` is an ``Expr`` it is
        evaluated on the same DataFrame; otherwise it is passed as is.
        """
        op_symbol = OP_SYMBOLS.get(op, op)
        other_is_expr = isinstance(other, Expr)
        other_repr = f"{other._repr_str}" if other_is_expr else repr(other)

        # Reflected operators are displayed with ``other`` on the left, which
        # is how the user wrote the expression.
        if op.startswith("__r"):
            repr_str = f"({other_repr} {op_symbol} {self._repr_str})"
        else:
            repr_str = f"({self._repr_str} {op_symbol} {other_repr})"

        if other_is_expr:
            return Expr(
                lambda df: getattr(self._func(df), op)(other._func(df)), repr_str
            )
        return Expr(lambda df: getattr(self._func(df), op)(other), repr_str)

    # Binary ops
    def __add__(self, other: Any) -> Expr:
        return self._with_binary_op("__add__", other)

    def __radd__(self, other: Any) -> Expr:
        return self._with_binary_op("__radd__", other)

    def __sub__(self, other: Any) -> Expr:
        return self._with_binary_op("__sub__", other)

    def __rsub__(self, other: Any) -> Expr:
        return self._with_binary_op("__rsub__", other)

    def __mul__(self, other: Any) -> Expr:
        return self._with_binary_op("__mul__", other)

    def __rmul__(self, other: Any) -> Expr:
        return self._with_binary_op("__rmul__", other)

    def __truediv__(self, other: Any) -> Expr:
        return self._with_binary_op("__truediv__", other)

    def __rtruediv__(self, other: Any) -> Expr:
        return self._with_binary_op("__rtruediv__", other)

    def __floordiv__(self, other: Any) -> Expr:
        return self._with_binary_op("__floordiv__", other)

    def __rfloordiv__(self, other: Any) -> Expr:
        return self._with_binary_op("__rfloordiv__", other)

    def __ge__(self, other: Any) -> Expr:
        return self._with_binary_op("__ge__", other)

    def __gt__(self, other: Any) -> Expr:
        return self._with_binary_op("__gt__", other)

    def __le__(self, other: Any) -> Expr:
        return self._with_binary_op("__le__", other)

    def __lt__(self, other: Any) -> Expr:
        return self._with_binary_op("__lt__", other)

    def __eq__(self, other: object) -> Expr:  # type: ignore[override]
        return self._with_binary_op("__eq__", other)

    def __ne__(self, other: object) -> Expr:  # type: ignore[override]
        return self._with_binary_op("__ne__", other)

    def __mod__(self, other: Any) -> Expr:
        return self._with_binary_op("__mod__", other)

    # Everything else
    def __getattr__(self, attr: str, /) -> Callable[..., Expr]:
        """
        Return a factory that defers a call to ``attr`` on the evaluated value.

        Python only invokes ``__getattr__`` for names not found through
        normal lookup, so this handles every attribute not defined on
        ``Expr``. The returned callable records its arguments and produces
        a new ``Expr``; ``Expr`` arguments are evaluated on the same
        DataFrame at evaluation time.
        """

        def func(df: DataFrame, *args: Any, **kwargs: Any) -> Any:
            parsed_args = parse_args(df, *args)
            parsed_kwargs = parse_kwargs(df, **kwargs)
            return getattr(self(df), attr)(*parsed_args, **parsed_kwargs)

        def wrapper(*args: Any, **kwargs: Any) -> Expr:
            # Create a readable representation for method calls
            args_str = Expr._format_call_args(args, kwargs)
            repr_str = f"{self._repr_str}.{attr}({args_str})"

            return Expr(lambda df: func(df, *args, **kwargs), repr_str)

        return wrapper

    def __repr__(self) -> str:
        return self._repr_str or "Expr(...)"

    # Namespaces
    @property
    def dt(self) -> NamespaceExpr:
        return NamespaceExpr(self, "dt")

    @property
    def str(self) -> NamespaceExpr:
        return NamespaceExpr(self, "str")

    @property
    def cat(self) -> NamespaceExpr:
        return NamespaceExpr(self, "cat")

    @property
    def list(self) -> NamespaceExpr:
        return NamespaceExpr(self, "list")

    @property
    def sparse(self) -> NamespaceExpr:
        return NamespaceExpr(self, "sparse")

    @property
    def struct(self) -> NamespaceExpr:
        return NamespaceExpr(self, "struct")


class NamespaceExpr:
    """
    Proxy for a Series accessor namespace (such as ``str`` or ``dt``).

    Attribute access on this object is forwarded, lazily, to the accessor
    named ``namespace`` on the value the wrapped expression evaluates to.

    Parameters
    ----------
    func : Expr
        Expression whose evaluated value carries the accessor.
    namespace : str
        Name of the accessor attribute on ``Series``.
    """

    def __init__(self, func: Expr, namespace: str) -> None:
        self._func = func
        self._namespace = namespace

    @staticmethod
    def _format_call_args(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
        """
        Render call arguments for use inside a repr string.

        ``Expr`` arguments are shown through their own repr string, without
        surrounding quotes, whether passed positionally or by keyword;
        everything else is shown with ``repr``.
        """

        def render(value: Any) -> str:
            if isinstance(value, Expr):
                return f"{value._repr_str}"
            return repr(value)

        parts = [render(arg) for arg in args]
        parts.extend(f"{key}={render(val)}" for key, val in kwargs.items())
        return ", ".join(parts)

    def __getattr__(self, attr: str) -> Any:
        """
        Defer access to ``attr`` on the accessor namespace.

        If ``attr`` is a property on the accessor found on ``Series``, an
        ``Expr`` is returned directly. Otherwise a callable is returned
        that records its arguments and produces an ``Expr``; ``Expr``
        arguments are evaluated on the same DataFrame at evaluation time.
        """
        # Looked up on the Series class itself so that properties can be
        # told apart from methods without having any data available.
        if isinstance(getattr(getattr(Series, self._namespace), attr), property):
            repr_str = f"{self._func._repr_str}.{self._namespace}.{attr}"
            return Expr(
                lambda df: getattr(getattr(self._func(df), self._namespace), attr),
                repr_str,
            )

        def func(df: DataFrame, *args: Any, **kwargs: Any) -> Any:
            parsed_args = parse_args(df, *args)
            parsed_kwargs = parse_kwargs(df, **kwargs)
            return getattr(getattr(self._func(df), self._namespace), attr)(
                *parsed_args, **parsed_kwargs
            )

        def wrapper(*args: Any, **kwargs: Any) -> Expr:
            # Create a readable representation for namespace method calls
            args_str = NamespaceExpr._format_call_args(args, kwargs)
            repr_str = f"{self._func._repr_str}.{self._namespace}.{attr}({args_str})"

            return Expr(lambda df: func(df, *args, **kwargs), repr_str)

        return wrapper


def col(col_name: Hashable) -> Expr:
    """
    Create an expression that selects column ``col_name`` from a DataFrame.

    Parameters
    ----------
    col_name : Hashable
        Label used to index the DataFrame when the expression is evaluated.

    Returns
    -------
    Expr
        Expression evaluating to ``df[col_name]``, with repr
        ``col(<col_name repr>)``.

    Raises
    ------
    TypeError
        If ``col_name`` is not hashable.
    """
    if isinstance(col_name, Hashable):

        def select(df):
            return df[col_name]

        return Expr(select, f"col({col_name!r})")

    msg = f"Expected Hashable, got: {type(col_name)}"
    raise TypeError(msg)
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class TestPDApi(Base):
funcs = [
"array",
"bdate_range",
"col",
"concat",
"crosstab",
"cut",
Expand Down
Loading