|
3 | 3 | import operator |
4 | 4 | from functools import reduce |
5 | 5 | from itertools import chain |
6 | | -from typing import TYPE_CHECKING, Any, cast |
| 6 | +from typing import TYPE_CHECKING, Any, Callable, cast |
7 | 7 |
|
8 | 8 | import ibis |
9 | 9 | import ibis.expr.types as ir |
@@ -33,6 +33,21 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non |
33 | 33 | self._backend_version = backend_version |
34 | 34 | self._version = version |
35 | 35 |
|
| 36 | + def _expr_from_callable( |
| 37 | + self, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: IbisExpr |
| 38 | + ) -> IbisExpr: |
| 39 | + def call(df: IbisLazyFrame) -> list[ir.Value]: |
| 40 | + cols = (col for _expr in exprs for col in _expr(df)) |
| 41 | + return [func(cols)] |
| 42 | + |
| 43 | + return self._expr( |
| 44 | + call=call, |
| 45 | + evaluate_output_names=combine_evaluate_output_names(*exprs), |
| 46 | + alias_output_names=combine_alias_output_names(*exprs), |
| 47 | + backend_version=self._backend_version, |
| 48 | + version=self._version, |
| 49 | + ) |
| 50 | + |
36 | 51 | @property |
37 | 52 | def selectors(self) -> IbisSelectorNamespace: |
38 | 53 | return IbisSelectorNamespace.from_namespace(self) |
@@ -86,94 +101,44 @@ def func(df: IbisLazyFrame) -> list[ir.Value]: |
86 | 101 | ) |
87 | 102 |
|
88 | 103 | def all_horizontal(self, *exprs: IbisExpr) -> IbisExpr: |
89 | | - def func(df: IbisLazyFrame) -> list[ir.Value]: |
90 | | - cols = chain.from_iterable(expr(df) for expr in exprs) |
91 | | - return [reduce(operator.and_, cols)] |
| 104 | + def func(cols: Iterable[ir.Value]) -> ir.Value: |
| 105 | + return reduce(operator.and_, cols) |
92 | 106 |
|
93 | | - return self._expr( |
94 | | - call=func, |
95 | | - evaluate_output_names=combine_evaluate_output_names(*exprs), |
96 | | - alias_output_names=combine_alias_output_names(*exprs), |
97 | | - backend_version=self._backend_version, |
98 | | - version=self._version, |
99 | | - ) |
| 107 | + return self._expr_from_callable(func, *exprs) |
100 | 108 |
|
101 | 109 | def any_horizontal(self, *exprs: IbisExpr) -> IbisExpr: |
102 | | - def func(df: IbisLazyFrame) -> list[ir.Value]: |
103 | | - cols = chain.from_iterable(expr(df) for expr in exprs) |
104 | | - return [reduce(operator.or_, cols)] |
| 110 | + def func(cols: Iterable[ir.Value]) -> ir.Value: |
| 111 | + return reduce(operator.or_, cols) |
105 | 112 |
|
106 | | - return self._expr( |
107 | | - call=func, |
108 | | - evaluate_output_names=combine_evaluate_output_names(*exprs), |
109 | | - alias_output_names=combine_alias_output_names(*exprs), |
110 | | - backend_version=self._backend_version, |
111 | | - version=self._version, |
112 | | - ) |
| 113 | + return self._expr_from_callable(func, *exprs) |
113 | 114 |
|
114 | 115 | def max_horizontal(self, *exprs: IbisExpr) -> IbisExpr: |
115 | | - def func(df: IbisLazyFrame) -> list[ir.Value]: |
116 | | - cols = chain.from_iterable(expr(df) for expr in exprs) |
117 | | - return [ibis.greatest(*cols)] |
| 116 | + def func(cols: Iterable[ir.Value]) -> ir.Value: |
| 117 | + return ibis.greatest(*cols) |
118 | 118 |
|
119 | | - return self._expr( |
120 | | - call=func, |
121 | | - evaluate_output_names=combine_evaluate_output_names(*exprs), |
122 | | - alias_output_names=combine_alias_output_names(*exprs), |
123 | | - backend_version=self._backend_version, |
124 | | - version=self._version, |
125 | | - ) |
| 119 | + return self._expr_from_callable(func, *exprs) |
126 | 120 |
|
127 | 121 | def min_horizontal(self, *exprs: IbisExpr) -> IbisExpr: |
128 | | - def func(df: IbisLazyFrame) -> list[ir.Value]: |
129 | | - cols = chain.from_iterable(expr(df) for expr in exprs) |
130 | | - return [ibis.least(*cols)] |
| 122 | + def func(cols: Iterable[ir.Value]) -> ir.Value: |
| 123 | + return ibis.least(*cols) |
131 | 124 |
|
132 | | - return self._expr( |
133 | | - call=func, |
134 | | - evaluate_output_names=combine_evaluate_output_names(*exprs), |
135 | | - alias_output_names=combine_alias_output_names(*exprs), |
136 | | - backend_version=self._backend_version, |
137 | | - version=self._version, |
138 | | - ) |
| 125 | + return self._expr_from_callable(func, *exprs) |
139 | 126 |
|
140 | 127 | def sum_horizontal(self, *exprs: IbisExpr) -> IbisExpr: |
141 | | - def func(df: IbisLazyFrame) -> list[ir.Value]: |
142 | | - cols = [e.fill_null(lit(0)) for _expr in exprs for e in _expr(df)] |
143 | | - return [reduce(operator.add, cols)] |
| 128 | + def func(cols: Iterable[ir.Value]) -> ir.Value: |
| 129 | + cols = (col.fill_null(lit(0)) for col in cols) |
| 130 | + return reduce(operator.add, cols) |
144 | 131 |
|
145 | | - return self._expr( |
146 | | - call=func, |
147 | | - evaluate_output_names=combine_evaluate_output_names(*exprs), |
148 | | - alias_output_names=combine_alias_output_names(*exprs), |
149 | | - backend_version=self._backend_version, |
150 | | - version=self._version, |
151 | | - ) |
| 132 | + return self._expr_from_callable(func, *exprs) |
152 | 133 |
|
153 | 134 | def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr: |
154 | | - def func(df: IbisLazyFrame) -> list[ir.Value]: |
155 | | - expr = ( |
156 | | - cast("ir.NumericColumn", e.fill_null(lit(0))) |
157 | | - for _expr in exprs |
158 | | - for e in _expr(df) |
159 | | - ) |
160 | | - non_null = ( |
161 | | - cast("ir.NumericColumn", e.isnull().ifelse(lit(0), lit(1))) |
162 | | - for _expr in exprs |
163 | | - for e in _expr(df) |
| 135 | + def func(cols: Iterable[ir.Value]) -> ir.Value: |
| 136 | + cols = list(cols) |
| 137 | + return reduce(operator.add, (col.fill_null(lit(0)) for col in cols)) / reduce( |
| 138 | + operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols) |
164 | 139 | ) |
165 | 140 |
|
166 | | - return [ |
167 | | - (reduce(lambda x, y: x + y, expr) / reduce(lambda x, y: x + y, non_null)) |
168 | | - ] |
169 | | - |
170 | | - return self._expr( |
171 | | - call=func, |
172 | | - evaluate_output_names=combine_evaluate_output_names(*exprs), |
173 | | - alias_output_names=combine_alias_output_names(*exprs), |
174 | | - backend_version=self._backend_version, |
175 | | - version=self._version, |
176 | | - ) |
| 141 | + return self._expr_from_callable(func, *exprs) |
177 | 142 |
|
178 | 143 | @requires.backend_version((10, 0)) |
179 | 144 | def when(self, predicate: IbisExpr) -> IbisWhen: |
|
0 commit comments