Skip to content

Commit 067252b

Browse files
committed
refactor: Simplify EagerExprNameNamespace
1 parent b05f132 commit 067252b

File tree

2 files changed

+36
-53
lines changed

2 files changed

+36
-53
lines changed

narwhals/_compliant/expr.py

Lines changed: 30 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from narwhals._compliant.any_namespace import NameNamespace
1919
from narwhals._compliant.any_namespace import StringNamespace
2020
from narwhals._compliant.namespace import CompliantNamespace
21+
from narwhals._compliant.typing import AliasName
22+
from narwhals._compliant.typing import AliasNames
2123
from narwhals._compliant.typing import CompliantFrameT
2224
from narwhals._compliant.typing import CompliantLazyFrameT
2325
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
@@ -908,70 +910,45 @@ class EagerExprNameNamespace(
908910
EagerExprNamespace[EagerExprT], NameNamespace[EagerExprT], Generic[EagerExprT]
909911
):
910912
def keep(self) -> EagerExprT:
911-
return self._from_colname_func_and_alias_output_names(
912-
name_mapping_func=lambda name: name
913-
)
913+
return self._from_callable(lambda name: name, alias=False)
914914

915-
def map(self, function: Callable[[str], str]) -> EagerExprT:
916-
return self._from_colname_func_and_alias_output_names(
917-
name_mapping_func=function,
918-
alias_output_names=lambda output_names: [
919-
function(name) for name in output_names
920-
],
921-
)
915+
def map(self, function: AliasName) -> EagerExprT:
916+
return self._from_callable(function)
922917

923918
def prefix(self, prefix: str) -> EagerExprT:
924-
return self._from_colname_func_and_alias_output_names(
925-
name_mapping_func=lambda name: f"{prefix}{name}",
926-
alias_output_names=lambda output_names: [
927-
f"{prefix}{output_name}" for output_name in output_names
928-
],
929-
)
919+
return self._from_callable(lambda name: f"{prefix}{name}")
930920

931921
def suffix(self, suffix: str) -> EagerExprT:
932-
return self._from_colname_func_and_alias_output_names(
933-
name_mapping_func=lambda name: f"{name}{suffix}",
934-
alias_output_names=lambda output_names: [
935-
f"{output_name}{suffix}" for output_name in output_names
936-
],
937-
)
922+
return self._from_callable(lambda name: f"{name}{suffix}")
938923

939924
def to_lowercase(self) -> EagerExprT:
940-
return self._from_colname_func_and_alias_output_names(
941-
name_mapping_func=str.lower,
942-
alias_output_names=lambda output_names: [
943-
name.lower() for name in output_names
944-
],
945-
)
925+
return self._from_callable(str.lower)
946926

947927
def to_uppercase(self) -> EagerExprT:
948-
return self._from_colname_func_and_alias_output_names(
949-
name_mapping_func=str.upper,
950-
alias_output_names=lambda output_names: [
951-
name.upper() for name in output_names
952-
],
953-
)
928+
return self._from_callable(str.upper)
954929

955-
def _from_colname_func_and_alias_output_names(
956-
self,
957-
name_mapping_func: Callable[[str], str],
958-
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None = None,
959-
) -> EagerExprT:
960-
return type(self.compliant)(
961-
call=lambda df: [
962-
series.alias(name_mapping_func(name))
963-
for series, name in zip(
964-
self.compliant._call(df), self.compliant._evaluate_output_names(df)
965-
)
930+
@staticmethod
931+
def _alias_output_names(func: AliasName, /) -> AliasNames:
932+
def fn(output_names: Sequence[str], /) -> Sequence[str]:
933+
return [func(name) for name in output_names]
934+
935+
return fn
936+
937+
def _from_callable(self, func: AliasName, /, *, alias: bool = True) -> EagerExprT:
938+
expr = self.compliant
939+
return type(expr)(
940+
lambda df: [
941+
series.alias(func(name))
942+
for series, name in zip(expr(df), expr._evaluate_output_names(df))
966943
],
967-
depth=self.compliant._depth,
968-
function_name=self.compliant._function_name,
969-
evaluate_output_names=self.compliant._evaluate_output_names,
970-
alias_output_names=alias_output_names,
971-
backend_version=self.compliant._backend_version,
972-
implementation=self.compliant._implementation,
973-
version=self.compliant._version,
974-
call_kwargs=self.compliant._call_kwargs,
944+
depth=expr._depth,
945+
function_name=expr._function_name,
946+
evaluate_output_names=expr._evaluate_output_names,
947+
alias_output_names=self._alias_output_names(func) if alias else None,
948+
backend_version=expr._backend_version,
949+
implementation=expr._implementation,
950+
version=expr._version,
951+
call_kwargs=expr._call_kwargs,
975952
)
976953

977954

narwhals/_compliant/typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Callable
6+
from typing import Sequence
57
from typing import TypeVar
68

79
if TYPE_CHECKING:
@@ -17,6 +19,8 @@
1719
from narwhals._compliant.series import EagerSeries
1820

1921
__all__ = [
22+
"AliasName",
23+
"AliasNames",
2024
"CompliantDataFrameT",
2125
"CompliantFrameT",
2226
"CompliantLazyFrameT",
@@ -43,3 +47,5 @@
4347
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")
4448
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound="EagerSeries[Any]", covariant=True)
4549
EagerExprT = TypeVar("EagerExprT", bound="EagerExpr[Any, Any]")
50+
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
51+
AliasName: TypeAlias = Callable[[str], str]

0 commit comments

Comments
 (0)