Skip to content

Commit 22fdf75

Browse files
authored
chore: factor out sqrt from SparkLike & DuckDB (#2876)
1 parent 624518c commit 22fdf75

File tree

8 files changed

+45
-30
lines changed

8 files changed

+45
-30
lines changed

narwhals/_duckdb/expr.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,6 @@ def _log(expr: Expression) -> Expression:
303303

304304
return self._with_elementwise(_log)
305305

306-
def sqrt(self) -> Self:
307-
def _sqrt(expr: Expression) -> Expression:
308-
return when(expr < lit(0), lit(float("nan"))).otherwise(F("sqrt", expr))
309-
310-
return self._with_elementwise(_sqrt)
311-
312306
@property
313307
def str(self) -> DuckDBExprStringNamespace:
314308
return DuckDBExprStringNamespace(self)

narwhals/_duckdb/namespace.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ def _function(self, name: str, *args: Expression) -> Expression: # type: ignore
6363
def _lit(self, value: Any) -> Expression:
6464
return lit(value)
6565

66-
def _when(self, condition: Expression, value: Expression) -> Expression:
67-
return when(condition, value)
66+
def _when(
67+
self,
68+
condition: Expression,
69+
value: Expression,
70+
otherwise: Expression | None = None,
71+
) -> Expression:
72+
if otherwise is None:
73+
return when(condition, value)
74+
return when(condition, value).otherwise(otherwise)
6875

6976
def _coalesce(self, *exprs: Expression) -> Expression:
7077
return CoalesceOperator(*exprs)

narwhals/_ibis/expr.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,6 @@ def _log(expr: ir.NumericColumn) -> ir.Value:
335335

336336
return self._with_callable(_log)
337337

338-
def sqrt(self) -> Self:
339-
def _sqrt(expr: ir.NumericColumn) -> ir.Value:
340-
return ibis.cases((expr < lit(0), lit(float("nan"))), else_=expr.sqrt())
341-
342-
return self._with_callable(_sqrt)
343-
344338
@property
345339
def str(self) -> IbisExprStringNamespace:
346340
return IbisExprStringNamespace(self)

narwhals/_ibis/namespace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,12 @@ def _function(self, name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
5151
def _lit(self, value: Any) -> ir.Value:
5252
return lit(value)
5353

54-
def _when(self, condition: ir.Value, value: ir.Value) -> ir.Value:
55-
return ibis.cases((condition, value))
54+
def _when(
55+
self, condition: ir.Value, value: ir.Value, otherwise: ir.Expr | None = None
56+
) -> ir.Value:
57+
if otherwise is None:
58+
return ibis.cases((condition, value))
59+
return ibis.cases((condition, value), else_=otherwise) # pragma: no cover
5660

5761
def _coalesce(self, *exprs: ir.Value) -> ir.Value:
5862
return ibis.coalesce(*exprs)

narwhals/_spark_like/expr.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,6 @@ def _log(expr: Column) -> Column:
387387

388388
return self._with_elementwise(_log)
389389

390-
def sqrt(self) -> Self:
391-
def _sqrt(expr: Column) -> Column:
392-
return self._F.when(expr < 0, self._F.lit(float("nan"))).otherwise(
393-
self._F.sqrt(expr)
394-
)
395-
396-
return self._with_elementwise(_sqrt)
397-
398390
@property
399391
def str(self) -> SparkLikeExprStringNamespace:
400392
return SparkLikeExprStringNamespace(self)

narwhals/_spark_like/namespace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,12 @@ def _function(self, name: str, *args: Column | PythonLiteral) -> Column:
7373
def _lit(self, value: Any) -> Column:
7474
return self._F.lit(value)
7575

76-
def _when(self, condition: Column, value: Column) -> Column:
77-
return self._F.when(condition, value)
76+
def _when(
77+
self, condition: Column, value: Column, otherwise: Column | None = None
78+
) -> Column:
79+
if otherwise is None:
80+
return self._F.when(condition, value)
81+
return self._F.when(condition, value).otherwise(otherwise)
7882

7983
def _coalesce(self, *exprs: Column) -> Column:
8084
return self._F.coalesce(*exprs)

narwhals/_sql/expr.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,19 @@ def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExpr
181181
def _lit(self, value: Any) -> NativeExprT:
182182
return self.__narwhals_namespace__()._lit(value)
183183

184-
def _when(self, condition: NativeExprT, value: NativeExprT) -> NativeExprT:
185-
return self.__narwhals_namespace__()._when(condition, value)
186-
187184
def _coalesce(self, *expr: NativeExprT) -> NativeExprT:
188185
return self.__narwhals_namespace__()._coalesce(*expr)
189186

190187
def _count_star(self) -> NativeExprT: ...
191188

189+
def _when(
190+
self,
191+
condition: NativeExprT,
192+
value: NativeExprT,
193+
otherwise: NativeExprT | None = None,
194+
) -> NativeExprT:
195+
return self.__narwhals_namespace__()._when(condition, value, otherwise)
196+
192197
def _window_expression(
193198
self,
194199
expr: NativeExprT,
@@ -492,6 +497,16 @@ def round(self, decimals: int) -> Self:
492497
lambda expr: self._function("round", expr, self._lit(decimals))
493498
)
494499

500+
def sqrt(self) -> Self:
501+
def _sqrt(expr: NativeExprT) -> NativeExprT:
502+
return self._when(
503+
expr < self._lit(0), # type: ignore[operator]
504+
self._lit(float("nan")),
505+
self._function("sqrt", expr),
506+
)
507+
508+
return self._with_elementwise(_sqrt)
509+
495510
def exp(self) -> Self:
496511
return self._with_elementwise(lambda expr: self._function("exp", expr))
497512

narwhals/_sql/namespace.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ class SQLNamespace(
2020
):
2121
def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT: ...
2222
def _lit(self, value: Any) -> NativeExprT: ...
23-
def _when(self, condition: NativeExprT, value: NativeExprT) -> NativeExprT: ...
23+
def _when(
24+
self,
25+
condition: NativeExprT,
26+
value: NativeExprT,
27+
otherwise: NativeExprT | None = None,
28+
) -> NativeExprT: ...
2429
def _coalesce(self, *exprs: NativeExprT) -> NativeExprT: ...
2530

2631
# Horizontal functions

0 commit comments

Comments
 (0)