Skip to content

Commit dc9fcaa

Browse files
fix: Allow exprs in .filter constraints (#2114)
* fix: allow exprs in filter constraints * refactor(suggestion): Tidy-up constraints --------- Co-authored-by: dangotbanned <[email protected]>
1 parent be0f7b1 commit dc9fcaa

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

narwhals/dataframe.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,21 +202,22 @@ def filter(
202202
and isinstance(predicates[0], list)
203203
and all(isinstance(x, bool) for x in predicates[0])
204204
):
205+
from narwhals.functions import col
206+
205207
flat_predicates = flatten(predicates)
206208
check_expressions_preserve_length(*flat_predicates, function_name="filter")
207-
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
208209
plx = self.__narwhals_namespace__()
210+
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
211+
compliant_constraints = (
212+
(col(name) == v)._to_compliant_expr(plx)
213+
for name, v in constraints.items()
214+
)
209215
predicate = plx.all_horizontal(
210-
*chain(
211-
compliant_predicates,
212-
(plx.col(name) == v for name, v in constraints.items()),
213-
)
216+
*chain(compliant_predicates, compliant_constraints)
214217
)
215218
else:
216219
predicate = predicates[0]
217-
return self._from_compliant_dataframe(
218-
self._compliant_frame.filter(predicate),
219-
)
220+
return self._from_compliant_dataframe(self._compliant_frame.filter(predicate))
220221

221222
def sort(
222223
self: Self,

tests/frame/filter_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,17 @@ def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None:
5454
df = nw.from_native(constructor(data))
5555
with pytest.raises((LengthChangingExprError, ShapeError)):
5656
df.filter(nw.col("b").unique() > 2).lazy().collect()
57+
58+
59+
def test_filter_with_constrains(constructor: Constructor) -> None:
60+
data = {"a": [1, 3, 2], "b": [4, 4, 6]}
61+
df = nw.from_native(constructor(data))
62+
result_scalar = df.filter(a=3)
63+
expected_scalar = {"a": [3], "b": [4]}
64+
65+
assert_equal_data(result_scalar, expected_scalar)
66+
67+
result_expr = df.filter(a=nw.col("b") // 3)
68+
expected_expr = {"a": [1, 2], "b": [4, 6]}
69+
70+
assert_equal_data(result_expr, expected_expr)

0 commit comments

Comments
 (0)