diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py
index 79740752e9..50fc56559c 100644
--- a/daft/expressions/expressions.py
+++ b/daft/expressions/expressions.py
@@ -2175,6 +2175,19 @@ def list_distinct(self) -> Expression:
 
         return list_distinct(self)
 
+    def list_contains(self, item: Expression) -> Expression:
+        """Checks if each list contains the specified item.
+
+        Evaluates to null for rows where the list or the item is null; null elements
+        inside a list never count as a match.
+
+        Tip: See Also
+            [`daft.functions.list_contains`](https://docs.daft.ai/en/stable/api/functions/list_contains/)
+        """
+        from daft.functions import list_contains
+
+        return list_contains(self, item)
+
     def list_map(self, mapper: Expression) -> Expression:
         """Evaluates an expression on all elements in the list.
 
diff --git a/daft/functions/__init__.py b/daft/functions/__init__.py
index c210a4932c..c75d7b0a87 100644
--- a/daft/functions/__init__.py
+++ b/daft/functions/__init__.py
@@ -95,6 +95,7 @@
     list_map,
     explode,
     list_append,
+    list_contains,
     to_list,
 )
 from .llm import llm_generate
@@ -322,6 +323,7 @@
     "list_append",
     "list_bool_and",
     "list_bool_or",
+    "list_contains",
     "list_count",
     "list_distinct",
     "list_join",
diff --git a/daft/functions/list.py b/daft/functions/list.py
index 47cdb3ffdd..7be122dce6 100644
--- a/daft/functions/list.py
+++ b/daft/functions/list.py
@@ -458,6 +458,55 @@ def list_append(list_expr: Expression, other: Expression) -> Expression:
     return Expression._call_builtin_scalar_fn("list_append", list_expr, other)
 
 
+def list_contains(list_expr: Expression, item: Expression) -> Expression:
+    """Checks if each list contains the specified item.
+
+    Rows where the list or the item is null evaluate to null. Null elements inside
+    a list are skipped and never count as a match.
+
+    Args:
+        list_expr: expression to search in
+        item: value or column of values to search for
+
+    Returns:
+        Boolean expression indicating whether each list contains the item
+
+    Examples:
+        >>> import daft
+        >>> from daft.functions import list_contains
+        >>>
+        >>> df = daft.from_pydict({"a": [[1, 2, 3], [2, 4], [1, 3, 5], []]})
+        >>> df.where(list_contains(df["a"], 3)).show()
+        ╭─────────────╮
+        │ a           │
+        │ ---         │
+        │ List[Int64] │
+        ╞═════════════╡
+        │ [1, 2, 3]   │
+        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+        │ [1, 3, 5]   │
+        ╰─────────────╯
+        <BLANKLINE>
+        (Showing first 2 of 2 rows)
+
+        >>> # Check against another column
+        >>> df2 = daft.from_pydict({"lists": [[1, 2], [3, 4]], "items": [1, 4]})
+        >>> df2.with_column("match", list_contains(df2["lists"], df2["items"])).show()
+        ╭─────────────┬───────┬───────╮
+        │ lists       ┆ items ┆ match │
+        │ ---         ┆ ---   ┆ ---   │
+        │ List[Int64] ┆ Int64 ┆ Bool  │
+        ╞═════════════╪═══════╪═══════╡
+        │ [1, 2]      ┆ 1     ┆ true  │
+        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
+        │ [3, 4]      ┆ 4     ┆ true  │
+        ╰─────────────┴───────┴───────╯
+        <BLANKLINE>
+        (Showing first 2 of 2 rows)
+    """
+    return Expression._call_builtin_scalar_fn("list_contains", list_expr, item)
+
+
 def to_list(*items: Expression) -> Expression:
     """Constructs a list from the item expressions.
 
diff --git a/daft/series.py b/daft/series.py
index bd36c325d8..bdbbf9a701 100644
--- a/daft/series.py
+++ b/daft/series.py
@@ -1113,6 +1113,9 @@ def get(self, idx: Series, default: Series) -> Series:
     def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series:
         return self._eval_expressions("list_sort", desc=desc, nulls_first=nulls_first)
 
+    def contains(self, item: Series) -> Series:
+        return self._eval_expressions("list_contains", item=item)
+
 
 class SeriesMapNamespace(SeriesNamespace):
     def get(self, key: Series) -> Series:
diff --git a/src/daft-functions-list/src/contains.rs b/src/daft-functions-list/src/contains.rs
new file mode 100644
index 0000000000..21650af9d0
--- /dev/null
+++ b/src/daft-functions-list/src/contains.rs
@@ -0,0 +1,84 @@
+use common_error::{DaftResult, ensure};
+use daft_core::{
+    datatypes::DataType,
+    prelude::{Field, Schema},
+    series::Series,
+};
+use daft_dsl::{
+    ExprRef,
+    functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
+};
+use serde::{Deserialize, Serialize};
+
+use crate::series::SeriesListExtension;
+
+/// Scalar function `list_contains(list, item)`: for each row, reports whether the
+/// list holds the given item.
+#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
+pub struct ListContains;
+
+#[typetag::serde]
+impl ScalarUDF for ListContains {
+    fn name(&self) -> &'static str {
+        "list_contains"
+    }
+
+    fn call(&self, inputs: daft_dsl::functions::FunctionArgs<Series>) -> DaftResult<Series> {
+        let list_series = inputs.required((0, "list"))?;
+        let item = inputs.required((1, "item"))?;
+
+        // A single-row item is broadcast so one value is checked against every list.
+        let item = if item.len() == 1 {
+            &item.broadcast(list_series.len())?
+        } else {
+            item
+        };
+
+        ensure!(
+            item.len() == list_series.len(),
+            ValueError: "Item length must match list length"
+        );
+
+        list_series.list_contains(item)
+    }
+
+    fn get_return_field(
+        &self,
+        inputs: FunctionArgs<ExprRef>,
+        schema: &Schema,
+    ) -> DaftResult<Field> {
+        ensure!(
+            inputs.len() == 2,
+            SchemaMismatch: "Expected 2 input args, got {}",
+            inputs.len()
+        );
+
+        let list_field = inputs.required((0, "list"))?.to_field(schema)?;
+        let item_field = inputs.required((1, "item"))?.to_field(schema)?;
+
+        ensure!(
+            list_field.dtype.is_list() || list_field.dtype.is_fixed_size_list(),
+            TypeError: "First argument must be a list, got {}",
+            list_field.dtype
+        );
+
+        // The item must match the element dtype unless either side is Null-typed.
+        let list_element_field = list_field.to_exploded_field()?;
+        if !list_element_field.dtype.is_null() {
+            ensure!(
+                list_element_field.dtype == item_field.dtype || item_field.dtype.is_null(),
+                TypeError: "Cannot search for item of type {} in list of type {}",
+                item_field.dtype,
+                list_element_field.dtype
+            );
+        }
+
+        Ok(Field::new(list_field.name, DataType::Boolean))
+    }
+}
+
+/// Builds a `list_contains` expression over `list_expr` and `item`.
+#[must_use]
+pub fn list_contains(list_expr: ExprRef, item: ExprRef) -> ExprRef {
+    ScalarFn::builtin(ListContains, vec![list_expr, item]).into()
+}
diff --git a/src/daft-functions-list/src/kernels.rs b/src/daft-functions-list/src/kernels.rs
index 8e621ac7da..ffd548ae8f 100644
--- a/src/daft-functions-list/src/kernels.rs
+++ b/src/daft-functions-list/src/kernels.rs
@@ -2,7 +2,7 @@
 
 use std::{iter::repeat_n, sync::Arc};
 
-use arrow::array::{BooleanBufferBuilder, make_comparator};
+use arrow::array::{BooleanBufferBuilder, BooleanBuilder, make_comparator};
 use common_error::DaftResult;
 use daft_arrow::offset::{Offsets, OffsetsBuffer};
 use daft_core::{
@@ -32,6 +32,7 @@ pub trait ListArrayExtension: Sized {
     fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult;
     fn list_bool_and(&self) -> DaftResult;
     fn list_bool_or(&self) -> DaftResult;
+    fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray>;
 }
 
 pub trait ListArrayAggExtension: Sized {
@@ -472,6 +473,78 @@ impl ListArrayExtension for ListArray {
             .rename(self.name())
             .with_nulls(Some(null_buffer))
     }
+
+    /// Reports, per row, whether the list contains the matching row of `item`.
+    ///
+    /// A row is null when the list or the item is null; null elements inside a
+    /// list are skipped and never count as a match.
+    fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray> {
+        let list_nulls = self.nulls();
+        let mut builder = BooleanBuilder::with_capacity(self.len());
+        let field = Field::new(self.name(), DataType::Boolean);
+
+        // A Null-typed child holds no non-null elements, so no valid row can match.
+        if self.flat_child.data_type() == &DataType::Null {
+            for list_idx in 0..self.len() {
+                let valid = list_nulls.is_none_or(|nulls| nulls.is_valid(list_idx))
+                    && item.is_valid(list_idx);
+                if valid {
+                    builder.append_value(false);
+                } else {
+                    builder.append_null();
+                }
+            }
+
+            let arrow_array = Arc::new(builder.finish());
+            return BooleanArray::from_arrow(field, arrow_array);
+        }
+
+        // Compare in the child's dtype; hashes act as a pre-filter before the comparator.
+        let item = item.cast(self.flat_child.data_type())?;
+        let item_hashes = item.hash(None)?;
+        let child_hashes = self.flat_child.hash(None)?;
+
+        let child_arrow = self.flat_child.to_arrow()?;
+        let item_arrow = item.to_arrow()?;
+        // Surface comparator construction failures as errors rather than panicking.
+        let comparator = make_comparator(
+            child_arrow.as_ref(),
+            item_arrow.as_ref(),
+            Default::default(),
+        )
+        .map_err(|e| common_error::DaftError::ComputeError(e.to_string()))?;
+
+        for (list_idx, range) in self.offsets().ranges().enumerate() {
+            if list_nulls.is_some_and(|nulls| nulls.is_null(list_idx)) {
+                builder.append_null();
+                continue;
+            }
+
+            if !item.is_valid(list_idx) {
+                builder.append_null();
+                continue;
+            }
+
+            let item_hash = item_hashes.get(list_idx).unwrap();
+            let mut found = false;
+            for elem_idx in range {
+                let elem_idx = elem_idx as usize;
+                if child_arrow.is_null(elem_idx) {
+                    continue;
+                }
+                let elem_hash = child_hashes.get(elem_idx).unwrap();
+                if elem_hash == item_hash && comparator(elem_idx, list_idx).is_eq() {
+                    found = true;
+                    break;
+                }
+            }
+
+            builder.append_value(found);
+        }
+
+        let arrow_array = Arc::new(builder.finish());
+        BooleanArray::from_arrow(field, arrow_array)
+    }
 }
 
 impl ListArrayExtension for FixedSizeListArray {
@@ -483,6 +556,10 @@ impl ListArrayExtension for FixedSizeListArray {
         self.to_list().list_bool_or()
     }
 
+    fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray> {
+        self.to_list().list_contains(item)
+    }
+
     fn value_counts(&self) -> DaftResult {
         self.to_list().value_counts()
     }
diff --git a/src/daft-functions-list/src/lib.rs b/src/daft-functions-list/src/lib.rs
index 28bd81820a..1e1dcdf763 100644
--- a/src/daft-functions-list/src/lib.rs
+++ b/src/daft-functions-list/src/lib.rs
@@ -2,6 +2,7 @@ mod append;
 mod bool_and;
 mod bool_or;
 mod chunk;
+mod contains;
 mod count;
 mod count_distinct;
 mod distinct;
@@ -21,6 +22,7 @@ pub use append::{ListAppend, list_append as append};
 pub use bool_and::{ListBoolAnd, list_bool_and as bool_and};
 pub use bool_or::{ListBoolOr, list_bool_or as bool_or};
 pub use chunk::{ListChunk, list_chunk as chunk};
+pub use contains::{ListContains, list_contains as contains};
 pub use count::ListCount;
 pub use count_distinct::{ListCountDistinct, list_count_distinct as count_distinct};
 pub use distinct::{ListDistinct, list_distinct as distinct};
@@ -50,6 +52,7 @@ impl FunctionModule for ListFunctions {
         parent.add_fn(ListBoolAnd);
         parent.add_fn(ListBoolOr);
         parent.add_fn(ListChunk);
+        parent.add_fn(ListContains);
         parent.add_fn(ListCount);
         parent.add_fn(ListCountDistinct);
         parent.add_fn(ListDistinct);
diff --git a/src/daft-functions-list/src/series.rs b/src/daft-functions-list/src/series.rs
index 69a762e382..c311dc828a 100644
--- a/src/daft-functions-list/src/series.rs
+++ b/src/daft-functions-list/src/series.rs
@@ -29,6 +29,7 @@ pub trait SeriesListExtension: Sized {
     fn list_fill(&self, num: &Int64Array) -> DaftResult;
     fn list_distinct(&self) -> DaftResult;
     fn list_append(&self, other: &Self) -> DaftResult;
+    fn list_contains(&self, item: &Self) -> DaftResult<Self>;
 }
 
 impl SeriesListExtension for Series {
@@ -362,4 +363,16 @@ impl SeriesListExtension for Series {
 
         Ok(list_array.into_series())
     }
+
+    fn list_contains(&self, item: &Self) -> DaftResult<Self> {
+        match self.data_type() {
+            DataType::List(_) => Ok(self.list()?.list_contains(item)?.into_series()),
+            DataType::FixedSizeList(..) => {
+                Ok(self.fixed_size_list()?.list_contains(item)?.into_series())
+            }
+            dt => Err(DaftError::TypeError(format!(
+                "List contains not implemented for {dt}"
+            ))),
+        }
+    }
 }
diff --git a/tests/recordbatch/list/test_list_contains.py b/tests/recordbatch/list/test_list_contains.py
new file mode 100644
index 0000000000..d31217e98e
--- /dev/null
+++ b/tests/recordbatch/list/test_list_contains.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import pytest
+
+from daft import col
+from daft.datatype import DataType
+from daft.recordbatch import MicroPartition
+
+
+def test_list_contains_scalar_and_broadcast():
+    table = MicroPartition.from_pydict({"a": [[1, 2, 3], [4, 5, 6], [], [1]]})
+    result = table.eval_expression_list([col("a").list_contains(1)])
+    assert result.to_pydict()["a"] == [True, False, False, True]
+
+
+def test_list_contains_column_items():
+    table = MicroPartition.from_pydict({"lists": [[1, 2], [3, 4], [5, 6]], "items": [1, 4, 7]})
+    result = table.eval_expression_list([col("lists").list_contains(col("items"))])
+    assert result.to_pydict()["lists"] == [True, True, False]
+
+
+def test_list_contains_null_list_and_item():
+    table = MicroPartition.from_pydict({"lists": [[1, 2], None, [3, 4]], "items": [2, 2, None]})
+    result = table.eval_expression_list([col("lists").list_contains(col("items"))])
+    assert result.to_pydict()["lists"] == [True, None, None]
+
+
+@pytest.mark.parametrize(
+    "data,search_value,expected",
+    [
+        pytest.param([[1, 2, 3], [4, 5], [1]], 1, [True, False, True], id="int"),
+        pytest.param([[1.0, 2.0], [3.0, 4.0]], 2.0, [True, False], id="float"),
+        pytest.param([["a", "b"], ["c", "d"]], "a", [True, False], id="string"),
+        pytest.param([[True, False], [False, False]], True, [True, False], id="bool"),
+    ],
+)
+def test_list_contains_types(data, search_value, expected):
+    table = MicroPartition.from_pydict({"a": data})
+    result = table.eval_expression_list([col("a").list_contains(search_value)])
+    assert result.to_pydict()["a"] == expected
+
+
+def test_list_contains_nulls_in_list():
+    table = MicroPartition.from_pydict({"a": [[1, None, 3], [None, None], [None, 5]]})
+    result = table.eval_expression_list([col("a").list_contains(3)])
+    assert result.to_pydict()["a"] == [True, False, False]
+
+
+def test_list_contains_list_null_dtype():
+    table = MicroPartition.from_pydict({"a": [[None, None], [None], []], "items": [1, None, 1]})
+    result = table.eval_expression_list([col("a").list_contains(col("items"))])
+    assert result.to_pydict()["a"] == [False, None, False]
+
+
+def test_fixed_size_list_contains():
+    table = MicroPartition.from_pydict({"col": [["a", "b"], ["c", "d"], ["a", "c"]]})
+    fixed_dtype = DataType.fixed_size_list(DataType.string(), 2)
+    table = table.eval_expression_list([col("col").cast(fixed_dtype)])
+
+    result = table.eval_expression_list([col("col").list_contains("a")])
+    assert result.to_pydict()["col"] == [True, False, True]
+
+
+def test_list_contains_varying_lengths():
+    table = MicroPartition.from_pydict({"a": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]})
+    result = table.eval_expression_list([col("a").list_contains(3)])
+    assert result.to_pydict()["a"] == [False, False, True, True]
diff --git a/tests/series/test_list.py b/tests/series/test_list.py
index cc39ed60cf..b5a7325c97 100644
--- a/tests/series/test_list.py
+++ b/tests/series/test_list.py
@@ -86,3 +86,12 @@ def test_list_list_sort_multi_desc(fixed):
         [40, 30],
     ]
     assert res.to_pylist() == expected
+
+
+def test_list_contains_series():
+    data = Series.from_pylist([[1, 2], [3, 4], None, []])
+    items = Series.from_pylist([2, 5, 1, 1])
+
+    result = data.list.contains(items)
+
+    assert result.to_pylist() == [True, False, None, False]