@@ -57,59 +57,7 @@ def _expr(self) -> type[ArrowExpr]:
5757 def _series (self ) -> type [ArrowSeries ]:
5858 return ArrowSeries
5959
60- def _create_expr_from_callable (
61- self : Self ,
62- func : Callable [[ArrowDataFrame ], Sequence [ArrowSeries ]],
63- * ,
64- depth : int ,
65- function_name : str ,
66- evaluate_output_names : Callable [[ArrowDataFrame ], Sequence [str ]],
67- alias_output_names : Callable [[Sequence [str ]], Sequence [str ]] | None ,
68- call_kwargs : dict [str , Any ] | None = None ,
69- ) -> ArrowExpr :
70- from narwhals ._arrow .expr import ArrowExpr
71-
72- return ArrowExpr (
73- func ,
74- depth = depth ,
75- function_name = function_name ,
76- evaluate_output_names = evaluate_output_names ,
77- alias_output_names = alias_output_names ,
78- backend_version = self ._backend_version ,
79- version = self ._version ,
80- call_kwargs = call_kwargs ,
81- )
82-
83- def _create_expr_from_series (self : Self , series : ArrowSeries ) -> ArrowExpr :
84- from narwhals ._arrow .expr import ArrowExpr
85-
86- return ArrowExpr (
87- lambda _df : [series ],
88- depth = 0 ,
89- function_name = "series" ,
90- evaluate_output_names = lambda _df : [series .name ],
91- alias_output_names = None ,
92- backend_version = self ._backend_version ,
93- version = self ._version ,
94- )
95-
96- def _create_series_from_scalar (
97- self : Self , value : Any , * , reference_series : ArrowSeries
98- ) -> ArrowSeries :
99- from narwhals ._arrow .series import ArrowSeries
100-
101- if self ._backend_version < (13 ,) and hasattr (value , "as_py" ):
102- value = value .as_py ()
103- return ArrowSeries ._from_iterable (
104- [value ],
105- name = reference_series .name ,
106- backend_version = self ._backend_version ,
107- version = self ._version ,
108- )
109-
11060 def _create_compliant_series (self : Self , value : Any ) -> ArrowSeries :
111- from narwhals ._arrow .series import ArrowSeries
112-
11361 return ArrowSeries (
11462 native_series = pa .chunked_array ([value ]),
11563 name = "" ,
@@ -127,39 +75,26 @@ def __init__(
12775
12876 # --- selection ---
12977 def col (self : Self , * column_names : str ) -> ArrowExpr :
130- from narwhals ._arrow .expr import ArrowExpr
131-
132- return ArrowExpr .from_column_names (
133- passthrough_column_names (column_names ),
134- function_name = "col" ,
135- backend_version = self ._backend_version ,
136- version = self ._version ,
78+ return self ._expr .from_column_names (
79+ passthrough_column_names (column_names ), function_name = "col" , context = self
13780 )
13881
13982 def exclude (self : Self , excluded_names : Container [str ]) -> ArrowExpr :
140- return ArrowExpr .from_column_names (
83+ return self . _expr .from_column_names (
14184 partial (exclude_column_names , names = excluded_names ),
14285 function_name = "exclude" ,
143- backend_version = self ._backend_version ,
144- version = self ._version ,
86+ context = self ,
14587 )
14688
14789 def nth (self : Self , * column_indices : int ) -> ArrowExpr :
148- from narwhals ._arrow .expr import ArrowExpr
149-
150- return ArrowExpr .from_column_indices (
151- * column_indices , backend_version = self ._backend_version , version = self ._version
152- )
90+ return self ._expr .from_column_indices (* column_indices , context = self )
15391
15492 def len (self : Self ) -> ArrowExpr :
15593 # coverage bug? this is definitely hit
156- return ArrowExpr ( # pragma: no cover
94+ return self . _expr ( # pragma: no cover
15795 lambda df : [
15896 ArrowSeries ._from_iterable (
159- [len (df ._native_frame )],
160- name = "len" ,
161- backend_version = self ._backend_version ,
162- version = self ._version ,
97+ [len (df ._native_frame )], name = "len" , context = self
16398 )
16499 ],
165100 depth = 0 ,
@@ -171,26 +106,20 @@ def len(self: Self) -> ArrowExpr:
171106 )
172107
173108 def all (self : Self ) -> ArrowExpr :
174- return ArrowExpr .from_column_names (
175- get_column_names ,
176- function_name = "all" ,
177- backend_version = self ._backend_version ,
178- version = self ._version ,
109+ return self ._expr .from_column_names (
110+ get_column_names , function_name = "all" , context = self
179111 )
180112
181113 def lit (self : Self , value : Any , dtype : DType | None ) -> ArrowExpr :
182114 def _lit_arrow_series (_ : ArrowDataFrame ) -> ArrowSeries :
183115 arrow_series = ArrowSeries ._from_iterable (
184- data = [value ],
185- name = "literal" ,
186- backend_version = self ._backend_version ,
187- version = self ._version ,
116+ data = [value ], name = "literal" , context = self
188117 )
189118 if dtype :
190119 return arrow_series .cast (dtype )
191120 return arrow_series
192121
193- return ArrowExpr (
122+ return self . _expr (
194123 lambda df : [_lit_arrow_series (df )],
195124 depth = 0 ,
196125 function_name = "lit" ,
@@ -200,30 +129,34 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
200129 version = self ._version ,
201130 )
202131
203- def all_horizontal (self : Self , * exprs : ArrowExpr ) -> ArrowExpr :
132+ # NOTE: Needs to be resolved in `EagerNamespace`
133+ # Probably, by adding an `EagerExprT` typevar
134+ def all_horizontal (self : Self , * exprs : ArrowExpr ) -> ArrowExpr : # type: ignore[override]
204135 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
205136 series = chain .from_iterable (expr (df ) for expr in exprs )
206137 return [reduce (operator .and_ , align_series_full_broadcast (* series ))]
207138
208- return self ._create_expr_from_callable (
139+ return self ._expr . _from_callable (
209140 func = func ,
210141 depth = max (x ._depth for x in exprs ) + 1 ,
211142 function_name = "all_horizontal" ,
212143 evaluate_output_names = combine_evaluate_output_names (* exprs ),
213144 alias_output_names = combine_alias_output_names (* exprs ),
145+ context = self ,
214146 )
215147
216148 def any_horizontal (self : Self , * exprs : ArrowExpr ) -> ArrowExpr :
217149 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
218150 series = chain .from_iterable (expr (df ) for expr in exprs )
219151 return [reduce (operator .or_ , align_series_full_broadcast (* series ))]
220152
221- return self ._create_expr_from_callable (
153+ return self ._expr . _from_callable (
222154 func = func ,
223155 depth = max (x ._depth for x in exprs ) + 1 ,
224156 function_name = "any_horizontal" ,
225157 evaluate_output_names = combine_evaluate_output_names (* exprs ),
226158 alias_output_names = combine_alias_output_names (* exprs ),
159+ context = self ,
227160 )
228161
229162 def sum_horizontal (self : Self , * exprs : ArrowExpr ) -> ArrowExpr :
@@ -232,12 +165,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
232165 series = (s .fill_null (0 , strategy = None , limit = None ) for s in it )
233166 return [reduce (operator .add , align_series_full_broadcast (* series ))]
234167
235- return self ._create_expr_from_callable (
168+ return self ._expr . _from_callable (
236169 func = func ,
237170 depth = max (x ._depth for x in exprs ) + 1 ,
238171 function_name = "sum_horizontal" ,
239172 evaluate_output_names = combine_evaluate_output_names (* exprs ),
240173 alias_output_names = combine_alias_output_names (* exprs ),
174+ context = self ,
241175 )
242176
243177 def mean_horizontal (self : Self , * exprs : ArrowExpr ) -> IntoArrowExpr :
@@ -253,12 +187,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
253187 )
254188 return [reduce (operator .add , series ) / reduce (operator .add , non_na )]
255189
256- return self ._create_expr_from_callable (
190+ return self ._expr . _from_callable (
257191 func = func ,
258192 depth = max (x ._depth for x in exprs ) + 1 ,
259193 function_name = "mean_horizontal" ,
260194 evaluate_output_names = combine_evaluate_output_names (* exprs ),
261195 alias_output_names = combine_alias_output_names (* exprs ),
196+ context = self ,
262197 )
263198
264199 def min_horizontal (self : Self , * exprs : ArrowExpr ) -> ArrowExpr :
@@ -281,12 +216,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
281216 )
282217 ]
283218
284- return self ._create_expr_from_callable (
219+ return self ._expr . _from_callable (
285220 func = func ,
286221 depth = max (x ._depth for x in exprs ) + 1 ,
287222 function_name = "min_horizontal" ,
288223 evaluate_output_names = combine_evaluate_output_names (* exprs ),
289224 alias_output_names = combine_alias_output_names (* exprs ),
225+ context = self ,
290226 )
291227
292228 def max_horizontal (self : Self , * exprs : ArrowExpr ) -> ArrowExpr :
@@ -310,12 +246,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
310246 )
311247 ]
312248
313- return self ._create_expr_from_callable (
249+ return self ._expr . _from_callable (
314250 func = func ,
315251 depth = max (x ._depth for x in exprs ) + 1 ,
316252 function_name = "max_horizontal" ,
317253 evaluate_output_names = combine_evaluate_output_names (* exprs ),
318254 alias_output_names = combine_alias_output_names (* exprs ),
255+ context = self ,
319256 )
320257
321258 def concat (
@@ -381,12 +318,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
381318 )
382319 ]
383320
384- return self ._create_expr_from_callable (
321+ return self ._expr . _from_callable (
385322 func = func ,
386323 depth = max (x ._depth for x in exprs ) + 1 ,
387324 function_name = "concat_str" ,
388325 evaluate_output_names = combine_evaluate_output_names (* exprs ),
389326 alias_output_names = combine_alias_output_names (* exprs ),
327+ context = self ,
390328 )
391329
392330
@@ -407,16 +345,13 @@ def __init__(
407345 self ._version = version
408346
409347 def __call__ (self : Self , df : ArrowDataFrame ) -> Sequence [ArrowSeries ]:
410- plx = df .__narwhals_namespace__ ()
411348 condition = self ._condition (df )[0 ]
412349 condition_native = condition ._native_series
413350
414351 if isinstance (self ._then_value , ArrowExpr ):
415352 value_series = self ._then_value (df )[0 ]
416353 else :
417- value_series = plx ._create_series_from_scalar (
418- self ._then_value , reference_series = condition .alias ("literal" )
419- )
354+ value_series = condition .alias ("literal" )._from_scalar (self ._then_value )
420355 value_series ._broadcast = True
421356 value_series_native = extract_dataframe_comparand (
422357 len (df ), value_series , self ._backend_version
@@ -474,6 +409,7 @@ def __init__(
474409 backend_version : tuple [int , ...],
475410 version : Version ,
476411 call_kwargs : dict [str , Any ] | None = None ,
412+ implementation : Implementation | None = None ,
477413 ) -> None :
478414 self ._backend_version = backend_version
479415 self ._version = version
0 commit comments