Skip to content

Commit 96c5094

Browse files
committed
feat: list_contains expression
1 parent 97bfd49 commit 96c5094

File tree

10 files changed

+311
-0
lines changed

10 files changed

+311
-0
lines changed

daft/expressions/expressions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,6 +2175,16 @@ def list_distinct(self) -> Expression:
21752175

21762176
return list_distinct(self)
21772177

2178+
def list_contains(self, item: Expression) -> Expression:
    """Tests, for every list in this expression, whether `item` is one of its elements.

    Method form of the function of the same name; that function is imported
    inside the body and called with this expression as the list argument.

    Tip: See Also
        [`daft.functions.list_contains`](https://docs.daft.ai/en/stable/api/functions/list_contains/)
    """
    from daft.functions import list_contains as _list_contains

    contains_expr = _list_contains(self, item)
    return contains_expr
2187+
21782188
def list_map(self, mapper: Expression) -> Expression:
21792189
"""Evaluates an expression on all elements in the list.
21802190

daft/functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
list_map,
9696
explode,
9797
list_append,
98+
list_contains,
9899
to_list,
99100
)
100101
from .llm import llm_generate
@@ -322,6 +323,7 @@
322323
"list_append",
323324
"list_bool_and",
324325
"list_bool_or",
326+
"list_contains",
325327
"list_count",
326328
"list_distinct",
327329
"list_join",

daft/functions/list.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,52 @@ def list_append(list_expr: Expression, other: Expression) -> Expression:
458458
return Expression._call_builtin_scalar_fn("list_append", list_expr, other)
459459

460460

461+
def list_contains(list_expr: Expression, item: Expression) -> Expression:
    """Tests whether each list holds the given item.

    Args:
        list_expr: list expression whose rows are searched
        item: the value to look for, given either as a single value or as a column supplying one value per row

    Returns:
        Boolean expression that is true for the rows whose list contains the item

    Examples:
        >>> import daft
        >>> from daft.functions import list_contains
        >>>
        >>> df = daft.from_pydict({"a": [[1, 2, 3], [2, 4], [1, 3, 5], []]})
        >>> df.where(list_contains(df["a"], 3)).show()
        ╭─────────────╮
        │ a           │
        │ ---         │
        │ List[Int64] │
        ╞═════════════╡
        │ [1, 2, 3]   │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ [1, 3, 5]   │
        ╰─────────────╯
        <BLANKLINE>
        (Showing first 2 of 2 rows)

        >>> # Check against another column
        >>> df2 = daft.from_pydict({"lists": [[1, 2], [3, 4]], "items": [1, 4]})
        >>> df2.with_column("match", list_contains(df2["lists"], df2["items"])).show()
        ╭─────────────┬───────┬───────╮
        │ lists       ┆ items ┆ match │
        │ ---         ┆ ---   ┆ ---   │
        │ List[Int64] ┆ Int64 ┆ Bool  │
        ╞═════════════╪═══════╪═══════╡
        │ [1, 2]      ┆ 1     ┆ true  │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ [3, 4]      ┆ 4     ┆ true  │
        ╰─────────────┴───────┴───────╯
        <BLANKLINE>
        (Showing first 2 of 2 rows)
    """
    contains_expr = Expression._call_builtin_scalar_fn("list_contains", list_expr, item)
    return contains_expr
505+
506+
461507
def to_list(*items: Expression) -> Expression:
462508
"""Constructs a list from the item expressions.
463509

daft/series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,9 @@ def get(self, idx: Series, default: Series) -> Series:
11131113
def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series:
11141114
return self._eval_expressions("list_sort", desc=desc, nulls_first=nulls_first)
11151115

1116+
def contains(self, item: Series) -> Series:
    """Evaluates the ``list_contains`` expression on this series, passing ``item`` as the keyword argument."""
    evaluated = self._eval_expressions("list_contains", item=item)
    return evaluated
1118+
11161119

11171120
class SeriesMapNamespace(SeriesNamespace):
11181121
def get(self, key: Series) -> Series:
src/daft-functions-list/src/contains.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use common_error::{DaftResult, ensure};
2+
use daft_core::{
3+
datatypes::DataType,
4+
prelude::{Field, Schema},
5+
series::Series,
6+
};
7+
use daft_dsl::{
8+
ExprRef,
9+
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
10+
};
11+
use serde::{Deserialize, Serialize};
12+
13+
use crate::series::SeriesListExtension;
14+
15+
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
16+
/// Unit struct that implements the `list_contains` scalar function through its `ScalarUDF` impl.
pub struct ListContains;
17+
18+
#[typetag::serde]
19+
impl ScalarUDF for ListContains {
20+
fn name(&self) -> &'static str {
21+
"list_contains"
22+
}
23+
24+
fn call(&self, inputs: daft_dsl::functions::FunctionArgs<Series>) -> DaftResult<Series> {
25+
let list_series = inputs.required((0, "list"))?;
26+
let item = inputs.required((1, "item"))?;
27+
28+
let item = if item.len() == 1 {
29+
&item.broadcast(list_series.len())?
30+
} else {
31+
item
32+
};
33+
34+
ensure!(
35+
item.len() == list_series.len(),
36+
ValueError: "Item length must match list length"
37+
);
38+
39+
list_series.list_contains(item)
40+
}
41+
42+
fn get_return_field(
43+
&self,
44+
inputs: FunctionArgs<ExprRef>,
45+
schema: &Schema,
46+
) -> DaftResult<Field> {
47+
ensure!(
48+
inputs.len() == 2,
49+
SchemaMismatch: "Expected 2 input args, got {}",
50+
inputs.len()
51+
);
52+
53+
let list_field = inputs.required((0, "list"))?.to_field(schema)?;
54+
let item_field = inputs.required((1, "item"))?.to_field(schema)?;
55+
56+
ensure!(
57+
list_field.dtype.is_list() || list_field.dtype.is_fixed_size_list(),
58+
TypeError: "First argument must be a list, got {}",
59+
list_field.dtype
60+
);
61+
62+
let list_element_field = list_field.to_exploded_field()?;
63+
if !list_element_field.dtype.is_null() {
64+
ensure!(
65+
list_element_field.dtype == item_field.dtype || item_field.dtype.is_null(),
66+
TypeError: "Cannot search for item of type {} in list of type {}",
67+
item_field.dtype,
68+
list_element_field.dtype
69+
);
70+
}
71+
72+
Ok(Field::new(list_field.name, DataType::Boolean))
73+
}
74+
}
75+
76+
#[must_use]
77+
pub fn list_contains(list_expr: ExprRef, item: ExprRef) -> ExprRef {
78+
ScalarFn::builtin(ListContains, vec![list_expr, item]).into()
79+
}

src/daft-functions-list/src/kernels.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub trait ListArrayExtension: Sized {
3232
fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult<Self>;
3333
fn list_bool_and(&self) -> DaftResult<BooleanArray>;
3434
fn list_bool_or(&self) -> DaftResult<BooleanArray>;
35+
fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray>;
3536
}
3637

3738
pub trait ListArrayAggExtension: Sized {
@@ -472,6 +473,80 @@ impl ListArrayExtension for ListArray {
472473
.rename(self.name())
473474
.with_nulls(Some(null_buffer))
474475
}
476+
477+
fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray> {
478+
let list_nulls = self.nulls();
479+
let mut result = Vec::with_capacity(self.len());
480+
let mut result_nulls = Vec::with_capacity(self.len());
481+
482+
if self.flat_child.data_type() == &DataType::Null {
483+
for list_idx in 0..self.len() {
484+
let valid = list_nulls.is_none_or(|nulls| nulls.is_valid(list_idx))
485+
&& item.is_valid(list_idx);
486+
result.push(false);
487+
result_nulls.push(valid);
488+
}
489+
490+
let values = daft_arrow::bitmap::Bitmap::from_iter(result.iter().copied());
491+
let null_buffer =
492+
daft_arrow::buffer::NullBuffer::from_iter(result_nulls.iter().copied());
493+
494+
return BooleanArray::from_iter_values(values)
495+
.rename(self.name())
496+
.with_nulls(Some(null_buffer));
497+
}
498+
499+
let item = item.cast(self.flat_child.data_type())?;
500+
let item_hashes = item.hash(None)?;
501+
let child_hashes = self.flat_child.hash(None)?;
502+
503+
let child_arrow = self.flat_child.to_arrow()?;
504+
let item_arrow = item.to_arrow()?;
505+
let comparator = make_comparator(
506+
child_arrow.as_ref(),
507+
item_arrow.as_ref(),
508+
Default::default(),
509+
)
510+
.unwrap();
511+
512+
for (list_idx, range) in self.offsets().ranges().enumerate() {
513+
if list_nulls.is_some_and(|nulls| nulls.is_null(list_idx)) {
514+
result.push(false);
515+
result_nulls.push(false);
516+
continue;
517+
}
518+
519+
if !item.is_valid(list_idx) {
520+
result.push(false);
521+
result_nulls.push(false);
522+
continue;
523+
}
524+
525+
let item_hash = item_hashes.get(list_idx).unwrap();
526+
let mut found = false;
527+
for elem_idx in range {
528+
let elem_idx = elem_idx as usize;
529+
if child_arrow.is_null(elem_idx) {
530+
continue;
531+
}
532+
let elem_hash = child_hashes.get(elem_idx).unwrap();
533+
if elem_hash == item_hash && comparator(elem_idx, list_idx).is_eq() {
534+
found = true;
535+
break;
536+
}
537+
}
538+
539+
result.push(found);
540+
result_nulls.push(true);
541+
}
542+
543+
let values = daft_arrow::bitmap::Bitmap::from_iter(result.iter().copied());
544+
let null_buffer = daft_arrow::buffer::NullBuffer::from_iter(result_nulls.iter().copied());
545+
546+
BooleanArray::from_iter_values(values)
547+
.rename(self.name())
548+
.with_nulls(Some(null_buffer))
549+
}
475550
}
476551

477552
impl ListArrayExtension for FixedSizeListArray {
@@ -483,6 +558,10 @@ impl ListArrayExtension for FixedSizeListArray {
483558
self.to_list().list_bool_or()
484559
}
485560

561+
/// Converts this fixed-size list array with `to_list` and delegates to the resulting array's `list_contains`.
fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray> {
    let as_list = self.to_list();
    as_list.list_contains(item)
}
564+
486565
fn value_counts(&self) -> DaftResult<MapArray> {
487566
self.to_list().value_counts()
488567
}

src/daft-functions-list/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod append;
22
mod bool_and;
33
mod bool_or;
44
mod chunk;
5+
mod contains;
56
mod count;
67
mod count_distinct;
78
mod distinct;
@@ -21,6 +22,7 @@ pub use append::{ListAppend, list_append as append};
2122
pub use bool_and::{ListBoolAnd, list_bool_and as bool_and};
2223
pub use bool_or::{ListBoolOr, list_bool_or as bool_or};
2324
pub use chunk::{ListChunk, list_chunk as chunk};
25+
pub use contains::{ListContains, list_contains as contains};
2426
pub use count::ListCount;
2527
pub use count_distinct::{ListCountDistinct, list_count_distinct as count_distinct};
2628
pub use distinct::{ListDistinct, list_distinct as distinct};
@@ -50,6 +52,7 @@ impl FunctionModule for ListFunctions {
5052
parent.add_fn(ListBoolAnd);
5153
parent.add_fn(ListBoolOr);
5254
parent.add_fn(ListChunk);
55+
parent.add_fn(ListContains);
5356
parent.add_fn(ListCount);
5457
parent.add_fn(ListCountDistinct);
5558
parent.add_fn(ListDistinct);

src/daft-functions-list/src/series.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub trait SeriesListExtension: Sized {
2929
fn list_fill(&self, num: &Int64Array) -> DaftResult<Self>;
3030
fn list_distinct(&self) -> DaftResult<Self>;
3131
fn list_append(&self, other: &Self) -> DaftResult<Self>;
32+
fn list_contains(&self, item: &Self) -> DaftResult<Self>;
3233
}
3334

3435
impl SeriesListExtension for Series {
@@ -362,4 +363,16 @@ impl SeriesListExtension for Series {
362363

363364
Ok(list_array.into_series())
364365
}
366+
367+
fn list_contains(&self, item: &Self) -> DaftResult<Self> {
368+
match self.data_type() {
369+
DataType::List(_) => Ok(self.list()?.list_contains(item)?.into_series()),
370+
DataType::FixedSizeList(..) => {
371+
Ok(self.fixed_size_list()?.list_contains(item)?.into_series())
372+
}
373+
dt => Err(DaftError::TypeError(format!(
374+
"List contains not implemented for {dt}"
375+
))),
376+
}
377+
}
365378
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from daft import col
6+
from daft.datatype import DataType
7+
from daft.recordbatch import MicroPartition
8+
9+
10+
def test_list_contains_scalar_and_broadcast():
    """A literal item is checked against every row, including the empty list."""
    lists = [[1, 2, 3], [4, 5, 6], [], [1]]
    mp = MicroPartition.from_pydict({"a": lists})
    contains_expr = col("a").list_contains(1)
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["a"] == [True, False, False, True]
14+
15+
16+
def test_list_contains_column_items():
    """Each list is searched for the value found in the same row of another column."""
    columns = {"lists": [[1, 2], [3, 4], [5, 6]], "items": [1, 4, 7]}
    mp = MicroPartition.from_pydict(columns)
    contains_expr = col("lists").list_contains(col("items"))
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["lists"] == [True, True, False]
20+
21+
22+
def test_list_contains_null_list_and_item():
    """A null list or a null item gives a null result for that row."""
    columns = {"lists": [[1, 2], None, [3, 4]], "items": [2, 2, None]}
    mp = MicroPartition.from_pydict(columns)
    contains_expr = col("lists").list_contains(col("items"))
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["lists"] == [True, None, None]
26+
27+
28+
@pytest.mark.parametrize(
    "data,search_value,expected",
    [
        pytest.param([[1, 2, 3], [4, 5], [1]], 1, [True, False, True], id="int"),
        pytest.param([[1.0, 2.0], [3.0, 4.0]], 2.0, [True, False], id="float"),
        pytest.param([["a", "b"], ["c", "d"]], "a", [True, False], id="string"),
        pytest.param([[True, False], [False, False]], True, [True, False], id="bool"),
    ],
)
def test_list_contains_types(data, search_value, expected):
    """The lookup is exercised with int, float, string and bool elements."""
    mp = MicroPartition.from_pydict({"a": data})
    contains_expr = col("a").list_contains(search_value)
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["a"] == expected
41+
42+
43+
def test_list_contains_nulls_in_list():
    """Null elements inside a list do not match a non-null item."""
    lists = [[1, None, 3], [None, None], [None, 5]]
    mp = MicroPartition.from_pydict({"a": lists})
    contains_expr = col("a").list_contains(3)
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["a"] == [True, False, False]
47+
48+
49+
def test_list_contains_list_null_dtype():
    """Lists holding only nulls give False for a non-null item and null for a null item."""
    columns = {"a": [[None, None], [None], []], "items": [1, None, 1]}
    mp = MicroPartition.from_pydict(columns)
    contains_expr = col("a").list_contains(col("items"))
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["a"] == [False, None, False]
53+
54+
55+
def test_fixed_size_list_contains():
    """The lookup also works after the column is cast to a fixed-size list of two strings."""
    mp = MicroPartition.from_pydict({"col": [["a", "b"], ["c", "d"], ["a", "c"]]})
    pair_of_strings = DataType.fixed_size_list(DataType.string(), 2)
    as_fixed = mp.eval_expression_list([col("col").cast(pair_of_strings)])

    contains_expr = col("col").list_contains("a")
    evaluated = as_fixed.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["col"] == [True, False, True]
62+
63+
64+
def test_list_contains_varying_lengths():
    """Lists of one to four elements are each searched for the same literal."""
    lists = [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]
    mp = MicroPartition.from_pydict({"a": lists})
    contains_expr = col("a").list_contains(3)
    evaluated = mp.eval_expression_list([contains_expr])
    assert evaluated.to_pydict()["a"] == [False, False, True, True]

0 commit comments

Comments
 (0)