|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from daft import col |
| 6 | +from daft.datatype import DataType |
| 7 | +from daft.recordbatch import MicroPartition |
| 8 | + |
| 9 | + |
| 10 | +def test_list_contains_scalar_and_broadcast(): |
| 11 | + table = MicroPartition.from_pydict({"a": [[1, 2, 3], [4, 5, 6], [], [1]]}) |
| 12 | + result = table.eval_expression_list([col("a").list_contains(1)]) |
| 13 | + assert result.to_pydict()["a"] == [True, False, False, True] |
| 14 | + |
| 15 | + |
| 16 | +def test_list_contains_column_items(): |
| 17 | + table = MicroPartition.from_pydict({"lists": [[1, 2], [3, 4], [5, 6]], "items": [1, 4, 7]}) |
| 18 | + result = table.eval_expression_list([col("lists").list_contains(col("items"))]) |
| 19 | + assert result.to_pydict()["lists"] == [True, True, False] |
| 20 | + |
| 21 | + |
| 22 | +def test_list_contains_null_list_and_item(): |
| 23 | + table = MicroPartition.from_pydict({"lists": [[1, 2], None, [3, 4]], "items": [2, 2, None]}) |
| 24 | + result = table.eval_expression_list([col("lists").list_contains(col("items"))]) |
| 25 | + assert result.to_pydict()["lists"] == [True, None, None] |
| 26 | + |
| 27 | + |
| 28 | +@pytest.mark.parametrize( |
| 29 | + "data,search_value,expected", |
| 30 | + [ |
| 31 | + pytest.param([[1, 2, 3], [4, 5], [1]], 1, [True, False, True], id="int"), |
| 32 | + pytest.param([[1.0, 2.0], [3.0, 4.0]], 2.0, [True, False], id="float"), |
| 33 | + pytest.param([["a", "b"], ["c", "d"]], "a", [True, False], id="string"), |
| 34 | + pytest.param([[True, False], [False, False]], True, [True, False], id="bool"), |
| 35 | + ], |
| 36 | +) |
| 37 | +def test_list_contains_types(data, search_value, expected): |
| 38 | + table = MicroPartition.from_pydict({"a": data}) |
| 39 | + result = table.eval_expression_list([col("a").list_contains(search_value)]) |
| 40 | + assert result.to_pydict()["a"] == expected |
| 41 | + |
| 42 | + |
| 43 | +def test_list_contains_nulls_in_list(): |
| 44 | + table = MicroPartition.from_pydict({"a": [[1, None, 3], [None, None], [None, 5]]}) |
| 45 | + result = table.eval_expression_list([col("a").list_contains(3)]) |
| 46 | + assert result.to_pydict()["a"] == [True, False, False] |
| 47 | + |
| 48 | + |
| 49 | +def test_list_contains_list_null_dtype(): |
| 50 | + table = MicroPartition.from_pydict({"a": [[None, None], [None], []], "items": [1, None, 1]}) |
| 51 | + result = table.eval_expression_list([col("a").list_contains(col("items"))]) |
| 52 | + assert result.to_pydict()["a"] == [False, None, False] |
| 53 | + |
| 54 | + |
| 55 | +def test_fixed_size_list_contains(): |
| 56 | + table = MicroPartition.from_pydict({"col": [["a", "b"], ["c", "d"], ["a", "c"]]}) |
| 57 | + fixed_dtype = DataType.fixed_size_list(DataType.string(), 2) |
| 58 | + table = table.eval_expression_list([col("col").cast(fixed_dtype)]) |
| 59 | + |
| 60 | + result = table.eval_expression_list([col("col").list_contains("a")]) |
| 61 | + assert result.to_pydict()["col"] == [True, False, True] |
| 62 | + |
| 63 | + |
| 64 | +def test_list_contains_varying_lengths(): |
| 65 | + table = MicroPartition.from_pydict({"a": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]}) |
| 66 | + result = table.eval_expression_list([col("a").list_contains(3)]) |
| 67 | + assert result.to_pydict()["a"] == [False, False, True, True] |
0 commit comments