Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,6 +2175,16 @@ def list_distinct(self) -> Expression:

return list_distinct(self)

def list_contains(self, item: Expression) -> Expression:
"""Checks if each list contains the specified item.
Tip: See Also
[`daft.functions.list_contains`](https://docs.daft.ai/en/stable/api/functions/list_contains/)
"""
from daft.functions import list_contains

return list_contains(self, item)

def list_map(self, mapper: Expression) -> Expression:
"""Evaluates an expression on all elements in the list.
Expand Down
2 changes: 2 additions & 0 deletions daft/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
list_map,
explode,
list_append,
list_contains,
to_list,
)
from .llm import llm_generate
Expand Down Expand Up @@ -322,6 +323,7 @@
"list_append",
"list_bool_and",
"list_bool_or",
"list_contains",
"list_count",
"list_distinct",
"list_join",
Expand Down
46 changes: 46 additions & 0 deletions daft/functions/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,52 @@ def list_append(list_expr: Expression, other: Expression) -> Expression:
return Expression._call_builtin_scalar_fn("list_append", list_expr, other)


def list_contains(list_expr: Expression, item: Expression) -> Expression:
"""Checks if each list contains the specified item.
Args:
list_expr: expression to search in
item: value or column of values to search for
Returns:
Boolean expression indicating whether each list contains the item
Examples:
>>> import daft
>>> from daft.functions import list_contains
>>>
>>> df = daft.from_pydict({"a": [[1, 2, 3], [2, 4], [1, 3, 5], []]})
>>> df.where(list_contains(df["a"], 3)).show()
╭─────────────╮
│ a │
│ --- │
│ List[Int64] │
╞═════════════╡
│ [1, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [1, 3, 5] │
╰─────────────╯
<BLANKLINE>
(Showing first 2 of 2 rows)
>>> # Check against another column
>>> df2 = daft.from_pydict({"lists": [[1, 2], [3, 4]], "items": [1, 4]})
>>> df2.with_column("match", list_contains(df2["lists"], df2["items"])).show()
╭─────────────┬───────┬───────╮
│ lists ┆ items ┆ match │
│ --- ┆ --- ┆ --- │
│ List[Int64] ┆ Int64 ┆ Bool │
╞═════════════╪═══════╪═══════╡
│ [1, 2] ┆ 1 ┆ true │
├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [3, 4] ┆ 4 ┆ true │
╰─────────────┴───────┴───────╯
<BLANKLINE>
(Showing first 2 of 2 rows)
"""
return Expression._call_builtin_scalar_fn("list_contains", list_expr, item)


def to_list(*items: Expression) -> Expression:
"""Constructs a list from the item expressions.
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,9 @@ def get(self, idx: Series, default: Series) -> Series:
def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series:
return self._eval_expressions("list_sort", desc=desc, nulls_first=nulls_first)

def contains(self, item: Series) -> Series:
return self._eval_expressions("list_contains", item=item)


class SeriesMapNamespace(SeriesNamespace):
def get(self, key: Series) -> Series:
Expand Down
79 changes: 79 additions & 0 deletions src/daft-functions-list/src/contains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use common_error::{DaftResult, ensure};
use daft_core::{
datatypes::DataType,
prelude::{Field, Schema},
series::Series,
};
use daft_dsl::{
ExprRef,
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
};
use serde::{Deserialize, Serialize};

use crate::series::SeriesListExtension;

#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ListContains;

#[typetag::serde]
impl ScalarUDF for ListContains {
fn name(&self) -> &'static str {
"list_contains"
}

fn call(&self, inputs: daft_dsl::functions::FunctionArgs<Series>) -> DaftResult<Series> {
let list_series = inputs.required((0, "list"))?;
let item = inputs.required((1, "item"))?;

let item = if item.len() == 1 {
&item.broadcast(list_series.len())?
} else {
item
};

ensure!(
item.len() == list_series.len(),
ValueError: "Item length must match list length"
);

list_series.list_contains(item)
}

fn get_return_field(
&self,
inputs: FunctionArgs<ExprRef>,
schema: &Schema,
) -> DaftResult<Field> {
ensure!(
inputs.len() == 2,
SchemaMismatch: "Expected 2 input args, got {}",
inputs.len()
);

let list_field = inputs.required((0, "list"))?.to_field(schema)?;
let item_field = inputs.required((1, "item"))?.to_field(schema)?;

ensure!(
list_field.dtype.is_list() || list_field.dtype.is_fixed_size_list(),
TypeError: "First argument must be a list, got {}",
list_field.dtype
);

let list_element_field = list_field.to_exploded_field()?;
if !list_element_field.dtype.is_null() {
ensure!(
list_element_field.dtype == item_field.dtype || item_field.dtype.is_null(),
TypeError: "Cannot search for item of type {} in list of type {}",
item_field.dtype,
list_element_field.dtype
);
}

Ok(Field::new(list_field.name, DataType::Boolean))
}
}

#[must_use]
pub fn list_contains(list_expr: ExprRef, item: ExprRef) -> ExprRef {
ScalarFn::builtin(ListContains, vec![list_expr, item]).into()
}
72 changes: 71 additions & 1 deletion src/daft-functions-list/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{iter::repeat_n, sync::Arc};

use arrow::array::{BooleanBufferBuilder, make_comparator};
use arrow::array::{BooleanBufferBuilder, BooleanBuilder, make_comparator};
use common_error::DaftResult;
use daft_arrow::offset::{Offsets, OffsetsBuffer};
use daft_core::{
Expand Down Expand Up @@ -32,6 +32,7 @@ pub trait ListArrayExtension: Sized {
fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult<Self>;
fn list_bool_and(&self) -> DaftResult<BooleanArray>;
fn list_bool_or(&self) -> DaftResult<BooleanArray>;
fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray>;
}

pub trait ListArrayAggExtension: Sized {
Expand Down Expand Up @@ -472,6 +473,71 @@ impl ListArrayExtension for ListArray {
.rename(self.name())
.with_nulls(Some(null_buffer))
}

fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray> {
let list_nulls = self.nulls();
let mut builder = BooleanBuilder::new();
let field = Field::new(self.name(), DataType::Boolean);

if self.flat_child.data_type() == &DataType::Null {
for list_idx in 0..self.len() {
let valid = list_nulls.is_none_or(|nulls| nulls.is_valid(list_idx))
&& item.is_valid(list_idx);
if valid {
builder.append_value(false);
} else {
builder.append_null();
}
}

let arrow_array = Arc::new(builder.finish());
return BooleanArray::from_arrow(field, arrow_array);
}

let item = item.cast(self.flat_child.data_type())?;
let item_hashes = item.hash(None)?;
let child_hashes = self.flat_child.hash(None)?;

let child_arrow = self.flat_child.to_arrow()?;
let item_arrow = item.to_arrow()?;
let comparator = make_comparator(
child_arrow.as_ref(),
item_arrow.as_ref(),
Default::default(),
)
.unwrap();

for (list_idx, range) in self.offsets().ranges().enumerate() {
if list_nulls.is_some_and(|nulls| nulls.is_null(list_idx)) {
builder.append_null();
continue;
}

if !item.is_valid(list_idx) {
builder.append_null();
continue;
}

let item_hash = item_hashes.get(list_idx).unwrap();
let mut found = false;
for elem_idx in range {
let elem_idx = elem_idx as usize;
if child_arrow.is_null(elem_idx) {
continue;
}
let elem_hash = child_hashes.get(elem_idx).unwrap();
if elem_hash == item_hash && comparator(elem_idx, list_idx).is_eq() {
found = true;
break;
}
}

builder.append_value(found);
}

let arrow_array = Arc::new(builder.finish());
BooleanArray::from_arrow(field, arrow_array)
}
}

impl ListArrayExtension for FixedSizeListArray {
Expand All @@ -483,6 +549,10 @@ impl ListArrayExtension for FixedSizeListArray {
self.to_list().list_bool_or()
}

fn list_contains(&self, item: &Series) -> DaftResult<BooleanArray> {
self.to_list().list_contains(item)
}

fn value_counts(&self) -> DaftResult<MapArray> {
self.to_list().value_counts()
}
Expand Down
3 changes: 3 additions & 0 deletions src/daft-functions-list/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod append;
mod bool_and;
mod bool_or;
mod chunk;
mod contains;
mod count;
mod count_distinct;
mod distinct;
Expand All @@ -21,6 +22,7 @@ pub use append::{ListAppend, list_append as append};
pub use bool_and::{ListBoolAnd, list_bool_and as bool_and};
pub use bool_or::{ListBoolOr, list_bool_or as bool_or};
pub use chunk::{ListChunk, list_chunk as chunk};
pub use contains::{ListContains, list_contains as contains};
pub use count::ListCount;
pub use count_distinct::{ListCountDistinct, list_count_distinct as count_distinct};
pub use distinct::{ListDistinct, list_distinct as distinct};
Expand Down Expand Up @@ -50,6 +52,7 @@ impl FunctionModule for ListFunctions {
parent.add_fn(ListBoolAnd);
parent.add_fn(ListBoolOr);
parent.add_fn(ListChunk);
parent.add_fn(ListContains);
parent.add_fn(ListCount);
parent.add_fn(ListCountDistinct);
parent.add_fn(ListDistinct);
Expand Down
13 changes: 13 additions & 0 deletions src/daft-functions-list/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub trait SeriesListExtension: Sized {
fn list_fill(&self, num: &Int64Array) -> DaftResult<Self>;
fn list_distinct(&self) -> DaftResult<Self>;
fn list_append(&self, other: &Self) -> DaftResult<Self>;
fn list_contains(&self, item: &Self) -> DaftResult<Self>;
}

impl SeriesListExtension for Series {
Expand Down Expand Up @@ -362,4 +363,16 @@ impl SeriesListExtension for Series {

Ok(list_array.into_series())
}

fn list_contains(&self, item: &Self) -> DaftResult<Self> {
match self.data_type() {
DataType::List(_) => Ok(self.list()?.list_contains(item)?.into_series()),
DataType::FixedSizeList(..) => {
Ok(self.fixed_size_list()?.list_contains(item)?.into_series())
}
dt => Err(DaftError::TypeError(format!(
"List contains not implemented for {dt}"
))),
}
}
}
67 changes: 67 additions & 0 deletions tests/recordbatch/list/test_list_contains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

import pytest

from daft import col
from daft.datatype import DataType
from daft.recordbatch import MicroPartition


def test_list_contains_scalar_and_broadcast():
table = MicroPartition.from_pydict({"a": [[1, 2, 3], [4, 5, 6], [], [1]]})
result = table.eval_expression_list([col("a").list_contains(1)])
assert result.to_pydict()["a"] == [True, False, False, True]


def test_list_contains_column_items():
table = MicroPartition.from_pydict({"lists": [[1, 2], [3, 4], [5, 6]], "items": [1, 4, 7]})
result = table.eval_expression_list([col("lists").list_contains(col("items"))])
assert result.to_pydict()["lists"] == [True, True, False]


def test_list_contains_null_list_and_item():
table = MicroPartition.from_pydict({"lists": [[1, 2], None, [3, 4]], "items": [2, 2, None]})
result = table.eval_expression_list([col("lists").list_contains(col("items"))])
assert result.to_pydict()["lists"] == [True, None, None]


@pytest.mark.parametrize(
"data,search_value,expected",
[
pytest.param([[1, 2, 3], [4, 5], [1]], 1, [True, False, True], id="int"),
pytest.param([[1.0, 2.0], [3.0, 4.0]], 2.0, [True, False], id="float"),
pytest.param([["a", "b"], ["c", "d"]], "a", [True, False], id="string"),
pytest.param([[True, False], [False, False]], True, [True, False], id="bool"),
],
)
def test_list_contains_types(data, search_value, expected):
table = MicroPartition.from_pydict({"a": data})
result = table.eval_expression_list([col("a").list_contains(search_value)])
assert result.to_pydict()["a"] == expected


def test_list_contains_nulls_in_list():
table = MicroPartition.from_pydict({"a": [[1, None, 3], [None, None], [None, 5]]})
result = table.eval_expression_list([col("a").list_contains(3)])
assert result.to_pydict()["a"] == [True, False, False]


def test_list_contains_list_null_dtype():
table = MicroPartition.from_pydict({"a": [[None, None], [None], []], "items": [1, None, 1]})
result = table.eval_expression_list([col("a").list_contains(col("items"))])
assert result.to_pydict()["a"] == [False, None, False]


def test_fixed_size_list_contains():
table = MicroPartition.from_pydict({"col": [["a", "b"], ["c", "d"], ["a", "c"]]})
fixed_dtype = DataType.fixed_size_list(DataType.string(), 2)
table = table.eval_expression_list([col("col").cast(fixed_dtype)])

result = table.eval_expression_list([col("col").list_contains("a")])
assert result.to_pydict()["col"] == [True, False, True]


def test_list_contains_varying_lengths():
table = MicroPartition.from_pydict({"a": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]})
result = table.eval_expression_list([col("a").list_contains(3)])
assert result.to_pydict()["a"] == [False, False, True, True]
Loading
Loading