Skip to content

Commit 07df868

Browse files
authored
add arrays contains to sqlite (#860)
1 parent aed4ae7 commit 07df868

File tree

8 files changed

+135
-5
lines changed

8 files changed

+135
-5
lines changed

examples/get_started/common_sql_functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def num_chars_udf(file):
99
return ([],)
1010

1111

12-
dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/")
12+
dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", anon=True)
1313
dc.map(num_chars_udf, params=["file"], output={"num_chars": list[str]}).select(
1414
"file.path", "num_chars"
1515
).show(5)
@@ -32,6 +32,12 @@ def num_chars_udf(file):
3232
.show(5)
3333
)
3434

35+
parts = string.split(path.name(C("file.path")), ".")
36+
chain = dc.mutate(
37+
isdog=array.contains(parts, "dog"),
38+
iscat=array.contains(parts, "cat"),
39+
)
40+
chain.select("file.path", "isdog", "iscat").show(5)
3541

3642
chain = dc.mutate(
3743
a=array.length(string.split("file.path", "/")),
@@ -79,6 +85,15 @@ def num_chars_udf(file):
7985
3 dogs-and-cats/cat.10.json cat.10 json
8086
4 dogs-and-cats/cat.100.jpg cat.100 jpg
8187
88+
[Limited by 5 rows]
89+
file isdog iscat
90+
path
91+
0 dogs-and-cats/cat.1.jpg 0 1
92+
1 dogs-and-cats/cat.1.json 0 1
93+
2 dogs-and-cats/cat.10.jpg 0 1
94+
3 dogs-and-cats/cat.10.json 0 1
95+
4 dogs-and-cats/cat.100.jpg 0 1
96+
8297
[Limited by 5 rows]
8398
Processed: 400 rows [00:00, 16496.93 rows/s]
8499
a b greatest least

src/datachain/func/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
row_number,
1616
sum,
1717
)
18-
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
18+
from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
1919
from .conditional import case, greatest, ifelse, isnone, least
2020
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
2121
from .random import rand
@@ -34,6 +34,7 @@
3434
"case",
3535
"collect",
3636
"concat",
37+
"contains",
3738
"cosine_distance",
3839
"count",
3940
"dense_rank",

src/datachain/func/array.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Sequence
2-
from typing import Union
2+
from typing import Any, Union
33

44
from datachain.sql.functions import array
55

@@ -140,6 +140,44 @@ def length(arg: Union[str, Sequence, Func]) -> Func:
140140
return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
141141

142142

143+
def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
144+
"""
145+
Checks whether the `arr` array has the `elem` element.
146+
147+
Args:
148+
arr (str | Sequence | Func): Array to check for the element.
149+
If a string is provided, it is assumed to be the name of the array column.
150+
If a sequence is provided, it is assumed to be an array of values.
151+
If a Func is provided, it is assumed to be a function returning an array.
152+
elem (Any): Element to check for in the array.
153+
154+
Returns:
155+
Func: A Func object that represents the contains function. Result of the
156+
function will be 1 if the element is present in the array, and 0 otherwise.
157+
158+
Example:
159+
```py
160+
dc.mutate(
161+
contains1=func.array.contains("signal.values", 3),
162+
contains2=func.array.contains([1, 2, 3, 4, 5], 7),
163+
)
164+
```
165+
"""
166+
167+
def inner(arg):
168+
is_json = type(elem) in [list, dict]
169+
return array.contains(arg, elem, is_json)
170+
171+
if isinstance(arr, (str, Func)):
172+
cols = [arr]
173+
args = None
174+
else:
175+
cols = None
176+
args = [arr]
177+
178+
return Func("contains", inner=inner, cols=cols, args=args, result_type=int)
179+
180+
143181
def sip_hash_64(arg: Union[str, Sequence]) -> Func:
144182
"""
145183
Computes the SipHash-64 hash of the array.

src/datachain/sql/functions/array.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy.sql.functions import GenericFunction
22

3-
from datachain.sql.types import Float, Int64
3+
from datachain.sql.types import Boolean, Float, Int64
44
from datachain.sql.utils import compiler_not_implemented
55

66

@@ -37,6 +37,17 @@ class length(GenericFunction): # noqa: N801
3737
inherit_cache = True
3838

3939

40+
class contains(GenericFunction): # noqa: N801
41+
"""
42+
Checks if element is in the array.
43+
"""
44+
45+
type = Boolean()
46+
package = "array"
47+
name = "contains"
48+
inherit_cache = True
49+
50+
4051
class sip_hash_64(GenericFunction): # noqa: N801
4152
"""
4253
Computes the SipHash-64 hash of the array.
@@ -51,4 +62,5 @@ class sip_hash_64(GenericFunction): # noqa: N801
5162
compiler_not_implemented(cosine_distance)
5263
compiler_not_implemented(euclidean_distance)
5364
compiler_not_implemented(length)
65+
compiler_not_implemented(contains)
5466
compiler_not_implemented(sip_hash_64)

src/datachain/sql/sqlite/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def setup():
8787
compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
8888
compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
8989
compiles(array.length, "sqlite")(compile_array_length)
90+
compiles(array.contains, "sqlite")(compile_array_contains)
9091
compiles(string.length, "sqlite")(compile_string_length)
9192
compiles(string.split, "sqlite")(compile_string_split)
9293
compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
@@ -269,13 +270,16 @@ def create_string_functions(conn):
269270

270271
_registered_function_creators["string_functions"] = create_string_functions
271272

272-
has_json_extension = functions_exist(["json_array_length"])
273+
has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
273274
if not has_json_extension:
274275

275276
def create_json_functions(conn):
276277
conn.create_function(
277278
"json_array_length", 1, py_json_array_length, deterministic=True
278279
)
280+
conn.create_function(
281+
"json_array_contains", 3, py_json_array_contains, deterministic=True
282+
)
279283

280284
_registered_function_creators["json_functions"] = create_json_functions
281285

@@ -428,10 +432,22 @@ def py_json_array_length(arr):
428432
return len(orjson.loads(arr))
429433

430434

435+
def py_json_array_contains(arr, value, is_json):
436+
if is_json:
437+
value = orjson.loads(value)
438+
return value in orjson.loads(arr)
439+
440+
431441
def compile_array_length(element, compiler, **kwargs):
432442
return compiler.process(func.json_array_length(*element.clauses.clauses), **kwargs)
433443

434444

445+
def compile_array_contains(element, compiler, **kwargs):
446+
return compiler.process(
447+
func.json_array_contains(*element.clauses.clauses), **kwargs
448+
)
449+
450+
435451
def compile_string_length(element, compiler, **kwargs):
436452
return compiler.process(func.length(*element.clauses.clauses), **kwargs)
437453

src/datachain/sql/sqlite/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def adapt_array(arr):
3131
return orjson.dumps(arr).decode("utf-8")
3232

3333

34+
def adapt_dict(dct):
35+
return orjson.dumps(dct).decode("utf-8")
36+
37+
3438
def convert_array(arr):
3539
return orjson.loads(arr)
3640

@@ -52,6 +56,7 @@ def adapt_np_generic(val):
5256

5357
def register_type_converters():
5458
sqlite3.register_adapter(list, adapt_array)
59+
sqlite3.register_adapter(dict, adapt_dict)
5560
sqlite3.register_converter("ARRAY", convert_array)
5661
if numpy_imported:
5762
sqlite3.register_adapter(np.ndarray, adapt_np_array)

tests/unit/sql/test_array.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,24 @@ def test_length(warehouse):
6565
assert result == ((4, 5, 2),)
6666

6767

68+
def test_contains(warehouse):
69+
query = select(
70+
func.contains(["abc", "def", "g", "hi"], "abc").label("contains1"),
71+
func.contains(["abc", "def", "g", "hi"], "cdf").label("contains2"),
72+
func.contains([3.0, 5.0, 1.0, 6.0, 1.0], 1.0).label("contains3"),
73+
func.contains([[1, None, 3], [4, 5, 6]], [1, None, 3]).label("contains4"),
74+
# Not supported yet by CH, need to add it later + some Pydantic model as
75+
# an input:
76+
# func.contains(
77+
# [{"c": 1, "a": True}, {"b": False}], {"a": True, "c": 1}
78+
# ).label("contains5"),
79+
func.contains([1, None, 3], None).label("contains6"),
80+
func.contains([1, True, 3], True).label("contains7"),
81+
)
82+
result = tuple(warehouse.db.execute(query))
83+
assert result == ((1, 0, 1, 1, 1, 1),)
84+
85+
6886
def test_length_on_split(warehouse):
6987
query = select(
7088
func.array.length(func.string.split(func.literal("abc/def/g/hi"), "/")),

tests/unit/test_func.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
isnone,
1212
literal,
1313
)
14+
from datachain.func.array import contains
1415
from datachain.func.random import rand
1516
from datachain.func.string import length as strlen
1617
from datachain.lib.signal_schema import SignalSchema
@@ -797,3 +798,27 @@ def test_isnone_with_ifelse_mutate(col):
797798
res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE"))
798799
assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2
799800
assert res.schema["test"] is str
801+
802+
803+
def test_array_contains():
804+
dc = DataChain.from_values(
805+
arr=[list(range(1, i)) * i for i in range(2, 7)],
806+
val=list(range(2, 7)),
807+
)
808+
809+
assert list(dc.mutate(res=contains("arr", 3)).order_by("val").collect("res")) == [
810+
0,
811+
0,
812+
1,
813+
1,
814+
1,
815+
]
816+
assert list(
817+
dc.mutate(res=contains(C("arr"), 3)).order_by("val").collect("res")
818+
) == [0, 0, 1, 1, 1]
819+
assert list(
820+
dc.mutate(res=contains(C("arr"), 10)).order_by("val").collect("res")
821+
) == [0, 0, 0, 0, 0]
822+
assert list(
823+
dc.mutate(res=contains(C("arr"), None)).order_by("val").collect("res")
824+
) == [0, 0, 0, 0, 0]

0 commit comments

Comments
 (0)