Skip to content

Commit afddd4c

Browse files
committed
refactor(DRAFT): Possibly all ArrowExpr?
Main source of errors atm are expr namespaces - still needs migrating
1 parent 7a1a653 commit afddd4c

File tree

10 files changed

+85
-122
lines changed

10 files changed

+85
-122
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def unique(
778778

779779
keep_idx = self.simple_select(*subset).is_unique()
780780
plx = self.__narwhals_namespace__()
781-
return self.filter(plx._create_expr_from_series(keep_idx))
781+
return self.filter(plx._expr._from_series(keep_idx))
782782

783783
def gather_every(self: Self, n: int, offset: int) -> Self:
784784
return self._from_native_frame(

narwhals/_arrow/expr.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from narwhals._arrow.namespace import ArrowNamespace
3030
from narwhals.dtypes import DType
3131
from narwhals.utils import Version
32+
from narwhals.utils import _FullContext
3233

3334

3435
class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
@@ -45,6 +46,7 @@ def __init__(
4546
backend_version: tuple[int, ...],
4647
version: Version,
4748
call_kwargs: dict[str, Any] | None = None,
49+
implementation: Implementation | None = None,
4850
) -> None:
4951
self._call = call
5052
self._depth = depth
@@ -63,8 +65,7 @@ def from_column_names(
6365
/,
6466
*,
6567
function_name: str,
66-
backend_version: tuple[int, ...],
67-
version: Version,
68+
context: _FullContext,
6869
) -> Self:
6970
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
7071
try:
@@ -91,16 +92,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
9192
function_name=function_name,
9293
evaluate_output_names=evaluate_column_names,
9394
alias_output_names=None,
94-
backend_version=backend_version,
95-
version=version,
95+
backend_version=context._backend_version,
96+
version=context._version,
9697
)
9798

9899
@classmethod
99100
def from_column_indices(
100-
cls: type[Self],
101-
*column_indices: int,
102-
backend_version: tuple[int, ...],
103-
version: Version,
101+
cls: type[Self], *column_indices: int, context: _FullContext
104102
) -> Self:
105103
from narwhals._arrow.series import ArrowSeries
106104

@@ -121,8 +119,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
121119
function_name="nth",
122120
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
123121
alias_output_names=None,
124-
backend_version=backend_version,
125-
version=version,
122+
backend_version=context._backend_version,
123+
version=context._version,
126124
)
127125

128126
def __narwhals_namespace__(self: Self) -> ArrowNamespace:

narwhals/_arrow/namespace.py

Lines changed: 30 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -57,59 +57,7 @@ def _expr(self) -> type[ArrowExpr]:
5757
def _series(self) -> type[ArrowSeries]:
5858
return ArrowSeries
5959

60-
def _create_expr_from_callable(
61-
self: Self,
62-
func: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
63-
*,
64-
depth: int,
65-
function_name: str,
66-
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
67-
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
68-
call_kwargs: dict[str, Any] | None = None,
69-
) -> ArrowExpr:
70-
from narwhals._arrow.expr import ArrowExpr
71-
72-
return ArrowExpr(
73-
func,
74-
depth=depth,
75-
function_name=function_name,
76-
evaluate_output_names=evaluate_output_names,
77-
alias_output_names=alias_output_names,
78-
backend_version=self._backend_version,
79-
version=self._version,
80-
call_kwargs=call_kwargs,
81-
)
82-
83-
def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr:
84-
from narwhals._arrow.expr import ArrowExpr
85-
86-
return ArrowExpr(
87-
lambda _df: [series],
88-
depth=0,
89-
function_name="series",
90-
evaluate_output_names=lambda _df: [series.name],
91-
alias_output_names=None,
92-
backend_version=self._backend_version,
93-
version=self._version,
94-
)
95-
96-
def _create_series_from_scalar(
97-
self: Self, value: Any, *, reference_series: ArrowSeries
98-
) -> ArrowSeries:
99-
from narwhals._arrow.series import ArrowSeries
100-
101-
if self._backend_version < (13,) and hasattr(value, "as_py"):
102-
value = value.as_py()
103-
return ArrowSeries._from_iterable(
104-
[value],
105-
name=reference_series.name,
106-
backend_version=self._backend_version,
107-
version=self._version,
108-
)
109-
11060
def _create_compliant_series(self: Self, value: Any) -> ArrowSeries:
111-
from narwhals._arrow.series import ArrowSeries
112-
11361
return ArrowSeries(
11462
native_series=pa.chunked_array([value]),
11563
name="",
@@ -127,39 +75,26 @@ def __init__(
12775

12876
# --- selection ---
12977
def col(self: Self, *column_names: str) -> ArrowExpr:
130-
from narwhals._arrow.expr import ArrowExpr
131-
132-
return ArrowExpr.from_column_names(
133-
passthrough_column_names(column_names),
134-
function_name="col",
135-
backend_version=self._backend_version,
136-
version=self._version,
78+
return self._expr.from_column_names(
79+
passthrough_column_names(column_names), function_name="col", context=self
13780
)
13881

13982
def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr:
140-
return ArrowExpr.from_column_names(
83+
return self._expr.from_column_names(
14184
partial(exclude_column_names, names=excluded_names),
14285
function_name="exclude",
143-
backend_version=self._backend_version,
144-
version=self._version,
86+
context=self,
14587
)
14688

14789
def nth(self: Self, *column_indices: int) -> ArrowExpr:
148-
from narwhals._arrow.expr import ArrowExpr
149-
150-
return ArrowExpr.from_column_indices(
151-
*column_indices, backend_version=self._backend_version, version=self._version
152-
)
90+
return self._expr.from_column_indices(*column_indices, context=self)
15391

15492
def len(self: Self) -> ArrowExpr:
15593
# coverage bug? this is definitely hit
156-
return ArrowExpr( # pragma: no cover
94+
return self._expr( # pragma: no cover
15795
lambda df: [
15896
ArrowSeries._from_iterable(
159-
[len(df._native_frame)],
160-
name="len",
161-
backend_version=self._backend_version,
162-
version=self._version,
97+
[len(df._native_frame)], name="len", context=self
16398
)
16499
],
165100
depth=0,
@@ -171,26 +106,20 @@ def len(self: Self) -> ArrowExpr:
171106
)
172107

173108
def all(self: Self) -> ArrowExpr:
174-
return ArrowExpr.from_column_names(
175-
get_column_names,
176-
function_name="all",
177-
backend_version=self._backend_version,
178-
version=self._version,
109+
return self._expr.from_column_names(
110+
get_column_names, function_name="all", context=self
179111
)
180112

181113
def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
182114
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
183115
arrow_series = ArrowSeries._from_iterable(
184-
data=[value],
185-
name="literal",
186-
backend_version=self._backend_version,
187-
version=self._version,
116+
data=[value], name="literal", context=self
188117
)
189118
if dtype:
190119
return arrow_series.cast(dtype)
191120
return arrow_series
192121

193-
return ArrowExpr(
122+
return self._expr(
194123
lambda df: [_lit_arrow_series(df)],
195124
depth=0,
196125
function_name="lit",
@@ -200,30 +129,34 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
200129
version=self._version,
201130
)
202131

203-
def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
132+
# NOTE: Needs to be resolved in `EagerNamespace`
133+
# Probably, by adding an `EagerExprT` typevar
134+
def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: # type: ignore[override]
204135
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
205136
series = chain.from_iterable(expr(df) for expr in exprs)
206137
return [reduce(operator.and_, align_series_full_broadcast(*series))]
207138

208-
return self._create_expr_from_callable(
139+
return self._expr._from_callable(
209140
func=func,
210141
depth=max(x._depth for x in exprs) + 1,
211142
function_name="all_horizontal",
212143
evaluate_output_names=combine_evaluate_output_names(*exprs),
213144
alias_output_names=combine_alias_output_names(*exprs),
145+
context=self,
214146
)
215147

216148
def any_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
217149
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
218150
series = chain.from_iterable(expr(df) for expr in exprs)
219151
return [reduce(operator.or_, align_series_full_broadcast(*series))]
220152

221-
return self._create_expr_from_callable(
153+
return self._expr._from_callable(
222154
func=func,
223155
depth=max(x._depth for x in exprs) + 1,
224156
function_name="any_horizontal",
225157
evaluate_output_names=combine_evaluate_output_names(*exprs),
226158
alias_output_names=combine_alias_output_names(*exprs),
159+
context=self,
227160
)
228161

229162
def sum_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
@@ -232,12 +165,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
232165
series = (s.fill_null(0, strategy=None, limit=None) for s in it)
233166
return [reduce(operator.add, align_series_full_broadcast(*series))]
234167

235-
return self._create_expr_from_callable(
168+
return self._expr._from_callable(
236169
func=func,
237170
depth=max(x._depth for x in exprs) + 1,
238171
function_name="sum_horizontal",
239172
evaluate_output_names=combine_evaluate_output_names(*exprs),
240173
alias_output_names=combine_alias_output_names(*exprs),
174+
context=self,
241175
)
242176

243177
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr:
@@ -253,12 +187,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
253187
)
254188
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
255189

256-
return self._create_expr_from_callable(
190+
return self._expr._from_callable(
257191
func=func,
258192
depth=max(x._depth for x in exprs) + 1,
259193
function_name="mean_horizontal",
260194
evaluate_output_names=combine_evaluate_output_names(*exprs),
261195
alias_output_names=combine_alias_output_names(*exprs),
196+
context=self,
262197
)
263198

264199
def min_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
@@ -281,12 +216,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
281216
)
282217
]
283218

284-
return self._create_expr_from_callable(
219+
return self._expr._from_callable(
285220
func=func,
286221
depth=max(x._depth for x in exprs) + 1,
287222
function_name="min_horizontal",
288223
evaluate_output_names=combine_evaluate_output_names(*exprs),
289224
alias_output_names=combine_alias_output_names(*exprs),
225+
context=self,
290226
)
291227

292228
def max_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
@@ -310,12 +246,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
310246
)
311247
]
312248

313-
return self._create_expr_from_callable(
249+
return self._expr._from_callable(
314250
func=func,
315251
depth=max(x._depth for x in exprs) + 1,
316252
function_name="max_horizontal",
317253
evaluate_output_names=combine_evaluate_output_names(*exprs),
318254
alias_output_names=combine_alias_output_names(*exprs),
255+
context=self,
319256
)
320257

321258
def concat(
@@ -381,12 +318,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
381318
)
382319
]
383320

384-
return self._create_expr_from_callable(
321+
return self._expr._from_callable(
385322
func=func,
386323
depth=max(x._depth for x in exprs) + 1,
387324
function_name="concat_str",
388325
evaluate_output_names=combine_evaluate_output_names(*exprs),
389326
alias_output_names=combine_alias_output_names(*exprs),
327+
context=self,
390328
)
391329

392330

@@ -407,16 +345,13 @@ def __init__(
407345
self._version = version
408346

409347
def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
410-
plx = df.__narwhals_namespace__()
411348
condition = self._condition(df)[0]
412349
condition_native = condition._native_series
413350

414351
if isinstance(self._then_value, ArrowExpr):
415352
value_series = self._then_value(df)[0]
416353
else:
417-
value_series = plx._create_series_from_scalar(
418-
self._then_value, reference_series=condition.alias("literal")
419-
)
354+
value_series = condition.alias("literal")._from_scalar(self._then_value)
420355
value_series._broadcast = True
421356
value_series_native = extract_dataframe_comparand(
422357
len(df), value_series, self._backend_version
@@ -474,6 +409,7 @@ def __init__(
474409
backend_version: tuple[int, ...],
475410
version: Version,
476411
call_kwargs: dict[str, Any] | None = None,
412+
implementation: Implementation | None = None,
477413
) -> None:
478414
self._backend_version = backend_version
479415
self._version = version

narwhals/_arrow/series.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from narwhals.typing import _1DArray
5555
from narwhals.typing import _2DArray
5656
from narwhals.utils import Version
57+
from narwhals.utils import _FullContext
5758

5859

5960
# TODO @dangotbanned: move into `_arrow.utils`
@@ -140,16 +141,20 @@ def _from_iterable(
140141
data: Iterable[Any],
141142
name: str,
142143
*,
143-
backend_version: tuple[int, ...],
144-
version: Version,
144+
context: _FullContext,
145145
) -> Self:
146146
return cls(
147147
chunked_array([data]),
148148
name=name,
149-
backend_version=backend_version,
150-
version=version,
149+
backend_version=context._backend_version,
150+
version=context._version,
151151
)
152152

153+
def _from_scalar(self, value: Any) -> Self:
154+
if self._backend_version < (13,) and hasattr(value, "as_py"):
155+
value = value.as_py()
156+
return super()._from_scalar(value)
157+
153158
def __narwhals_namespace__(self: Self) -> ArrowNamespace:
154159
from narwhals._arrow.namespace import ArrowNamespace
155160

@@ -570,12 +575,7 @@ def arg_true(self: Self) -> Self:
570575

571576
ser = self._native_series
572577
res = np.flatnonzero(ser)
573-
return self._from_iterable(
574-
res,
575-
name=self.name,
576-
backend_version=self._backend_version,
577-
version=self._version,
578-
)
578+
return self._from_iterable(res, name=self.name, context=self)
579579

580580
def item(self: Self, index: int | None = None) -> Any:
581581
if index is None:

0 commit comments

Comments
 (0)