Skip to content

Commit 4053603

Browse files
authored
fix: use specialised duckdb functions for std_pop/std_samp/var_pop/var_samp (#2289)
1 parent 4bc6074 commit 4053603

File tree

4 files changed

+44
-11
lines changed

4 files changed

+44
-11
lines changed

narwhals/_duckdb/expr.py

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -372,22 +372,38 @@ def len(self: Self) -> Self:
372372
return self._from_call(lambda _input: FunctionExpression("count"))
373373

374374
def std(self: Self, ddof: int) -> Self:
375+
if ddof == 0:
376+
return self._from_call(
377+
lambda _input: FunctionExpression("stddev_pop", _input)
378+
)
379+
if ddof == 1:
380+
return self._from_call(
381+
lambda _input: FunctionExpression("stddev_samp", _input)
382+
)
383+
375384
def _std(_input: duckdb.Expression) -> duckdb.Expression:
376385
n_samples = FunctionExpression("count", _input)
377-
# NOTE: Not implemented Error: Unable to transform python value of type '<class 'duckdb.duckdb.Expression'>' to DuckDB LogicalType
378386
return (
379387
FunctionExpression("stddev_pop", _input)
380388
* FunctionExpression("sqrt", n_samples)
381-
/ (FunctionExpression("sqrt", (n_samples - ddof))) # type: ignore[operator]
389+
/ (FunctionExpression("sqrt", (n_samples - lit(ddof))))
382390
)
383391

384392
return self._from_call(_std)
385393

386394
def var(self: Self, ddof: int) -> Self:
395+
if ddof == 0:
396+
return self._from_call(lambda _input: FunctionExpression("var_pop", _input))
397+
if ddof == 1:
398+
return self._from_call(lambda _input: FunctionExpression("var_samp", _input))
399+
387400
def _var(_input: duckdb.Expression) -> duckdb.Expression:
388401
n_samples = FunctionExpression("count", _input)
389-
# NOTE: Not implemented Error: Unable to transform python value of type '<class 'duckdb.duckdb.Expression'>' to DuckDB LogicalType
390-
return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof) # type: ignore[operator, no-any-return]
402+
return (
403+
FunctionExpression("var_pop", _input)
404+
* n_samples
405+
/ (n_samples - lit(ddof))
406+
)
391407

392408
return self._from_call(_var)
393409

tests/expr_and_series/arithmetic_test.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -338,3 +338,26 @@ def test_arithmetic_series_left_literal(
338338
df = nw.from_native(constructor_eager(data))
339339
result = df.select(getattr(lhs, attr)(nw.col("a")))
340340
assert_equal_data(result, {"literal": expected})
341+
342+
343+
def test_std_broadcating(constructor: Constructor) -> None:
344+
# `std(ddof=2)` fails for duckdb here
345+
df = nw.from_native(constructor({"a": [1, 2, 3]}))
346+
result = df.with_columns(b=nw.col("a").std()).sort("a")
347+
expected = {"a": [1, 2, 3], "b": [1.0, 1.0, 1.0]}
348+
assert_equal_data(result, expected)
349+
result = df.with_columns(b=nw.col("a").var()).sort("a")
350+
expected = {"a": [1, 2, 3], "b": [1.0, 1.0, 1.0]}
351+
assert_equal_data(result, expected)
352+
result = df.with_columns(b=nw.col("a").std(ddof=0)).sort("a")
353+
expected = {
354+
"a": [1, 2, 3],
355+
"b": [0.816496580927726, 0.816496580927726, 0.816496580927726],
356+
}
357+
assert_equal_data(result, expected)
358+
result = df.with_columns(b=nw.col("a").var(ddof=0)).sort("a")
359+
expected = {
360+
"a": [1, 2, 3],
361+
"b": [0.6666666666666666, 0.6666666666666666, 0.6666666666666666],
362+
}
363+
assert_equal_data(result, expected)

tests/expr_and_series/over_test.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,6 @@ def test_over_single(constructor: Constructor) -> None:
5050

5151

5252
def test_over_std_var(request: pytest.FixtureRequest, constructor: Constructor) -> None:
53-
if "duckdb" in str(constructor):
54-
request.applymarker(pytest.mark.xfail)
5553
if "cudf" in str(constructor):
5654
# https://github.com/rapidsai/cudf/issues/18159
5755
request.applymarker(pytest.mark.xfail)

tests/frame/add_test.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,11 @@
11
from __future__ import annotations
22

3-
import pytest
4-
53
import narwhals.stable.v1 as nw
64
from tests.utils import Constructor
75
from tests.utils import assert_equal_data
86

97

10-
def test_add(constructor: Constructor, request: pytest.FixtureRequest) -> None:
11-
if "duckdb" in str(constructor):
12-
request.applymarker(pytest.mark.xfail)
8+
def test_add(constructor: Constructor) -> None:
139
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
1410
df = nw.from_native(constructor(data))
1511
result = df.with_columns(

0 commit comments

Comments
 (0)