Skip to content

Commit 21b0c13

Browse files
authored
Addition of in Operator For QueryCondition (#1214)
1 parent d9b58ca commit 21b0c13

File tree

3 files changed

+136
-21
lines changed

3 files changed

+136
-21
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# In Progress
2+
3+
## API Changes
4+
* Addition of `in` operator for `QueryCondition` [#1214](https://github.com/TileDB-Inc/TileDB-Py/pull/1214)
5+
16
# TileDB-Py 0.16.3 Release Notes
27

38
## Packaging Notes

tiledb/query_condition.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,20 @@ class QueryCondition:
2323
Class representing a TileDB query condition object for attribute filtering
2424
pushdown.
2525
26+
A query condition is set with a string representing an expression
27+
as defined by the grammar below. A more straight forward example of usage is
28+
given beneath.
29+
2630
When querying a sparse array, only the values that satisfy the given
2731
condition are returned (coupled with their associated coordinates). An example
2832
may be found in `examples/query_condition_sparse.py`.
2933
3034
For dense arrays, the given shape of the query matches the shape of the output
3135
array. Values that DO NOT satisfy the given condition are filled with the
3236
TileDB default fill value. Different attribute types have different default
33-
fill values as outlined here (https://docs.tiledb.com/main/background/internal-mechanics/writing#default-fill-values). An example may be found in `examples/query_condition_dense.py`.
34-
35-
Set the query condition with a string representing an expression
36-
as defined by the grammar below. A more straight forward example of usage is
37-
given beneath.
37+
fill values as outlined here
38+
(https://docs.tiledb.com/main/background/internal-mechanics/writing#default-fill-values).
39+
An example may be found in `examples/query_condition_dense.py`.
3840
3941
**BNF:**
4042
@@ -56,15 +58,27 @@ class QueryCondition:
5658
5759
We intend to support ``not`` in future releases.
5860
59-
A Boolean expression contains a comparison operator. The operator works on a
61+
A Boolean expression may either be a comparison expression or membership
62+
expression.
63+
64+
``bool_expr ::= compare_expr | member_expr``
65+
66+
A comparison expression contains a comparison operator. The operator works on a
6067
TileDB attribute name and value.
6168
62-
``bool_expr ::= attr compare_op val | val compare_op attr | val compare_op attr compare_op val``
69+
``compare_expr ::= attr compare_op val
70+
| val compare_op attr
71+
| val compare_op attr compare_op val``
6372
6473
All comparison operators are supported.
6574
6675
``compare_op ::= < | > | <= | >= | == | !=``
6776
77+
A memership expression contains the membership operator, ``in``. The operator
78+
works on a TileDB attribute and list of values.
79+
80+
``member_expr ::= attr in <list>``
81+
6882
TileDB attribute names are Python valid variables or a ``attr()`` casted string.
6983
7084
``attr ::= <variable> | attr(<str>)``
@@ -79,9 +93,15 @@ class QueryCondition:
7993
>>> # Select cells where the attribute values for `foo` are less than 5
8094
>>> # and `bar` equal to string "asdf".
8195
>>> # Note precedence is equivalent to:
82-
>>> # (foo > 5 or ('asdf' == attr('b a r') and baz <= val(1.0)))
83-
>>> qc = QueryCondition("foo > 5 or 'asdf' == attr('b a r') and baz <= val(1.0)")
96+
>>> # tiledb.QueryCondition("foo > 5 or ('asdf' == attr('b a r') and baz <= val(1.0))")
97+
>>> qc = tiledb.QueryCondition("foo > 5 or 'asdf' == attr('b a r') and baz <= val(1.0)")
8498
>>> A.query(attr_cond=qc)
99+
>>>
100+
>>> # Select cells where the attribute values for `foo` are equal to
101+
>>> # 1, 2, or 3.
102+
>>> # Note this is equivalent to:
103+
>>> # tiledb.QueryCondition("foo == 1 or foo == 2 or foo == 3")
104+
>>> A.query(attr_cond=tiledb.QueryCondition("foo in [1, 2, 3]"))
85105
"""
86106

87107
expression: str
@@ -151,21 +171,52 @@ def visit_Eq(self, node):
151171
def visit_NotEq(self, node):
152172
return qc.TILEDB_NE
153173

174+
def visit_In(self, node):
175+
return node
176+
177+
def visit_List(self, node):
178+
return list(node.elts)
179+
154180
def visit_Compare(self, node: Type[ast.Compare]) -> PyQueryCondition:
155-
result = self.aux_visit_Compare(
156-
self.visit(node.left),
157-
self.visit(node.ops[0]),
158-
self.visit(node.comparators[0]),
159-
)
160-
161-
# Handling cases val < attr < val
162-
for lhs, op, rhs in zip(
163-
node.comparators[:-1], node.ops[1:], node.comparators[1:]
181+
operator = self.visit(node.ops[0])
182+
183+
if operator in (
184+
qc.TILEDB_GT,
185+
qc.TILEDB_GE,
186+
qc.TILEDB_LT,
187+
qc.TILEDB_LE,
188+
qc.TILEDB_EQ,
189+
qc.TILEDB_NE,
164190
):
165-
value = self.aux_visit_Compare(
166-
self.visit(lhs), self.visit(op), self.visit(rhs)
191+
result = self.aux_visit_Compare(
192+
self.visit(node.left),
193+
operator,
194+
self.visit(node.comparators[0]),
195+
)
196+
197+
# Handling cases val < attr < val
198+
for lhs, op, rhs in zip(
199+
node.comparators[:-1], node.ops[1:], node.comparators[1:]
200+
):
201+
value = self.aux_visit_Compare(
202+
self.visit(lhs), self.visit(op), self.visit(rhs)
203+
)
204+
result = result.combine(value, qc.TILEDB_AND)
205+
elif isinstance(operator, ast.In):
206+
rhs = node.comparators[0]
207+
if not isinstance(rhs, ast.List):
208+
raise tiledb.TileDBError(
209+
f"`in` operator syntax must be written as `attr in ['l', 'i', 's', 't']`"
210+
)
211+
212+
consts = self.visit(rhs)
213+
result = self.aux_visit_Compare(
214+
self.visit(node.left), qc.TILEDB_EQ, consts[0]
167215
)
168-
result = result.combine(value, qc.TILEDB_AND)
216+
217+
for val in consts[1:]:
218+
value = self.aux_visit_Compare(self.visit(node.left), qc.TILEDB_EQ, val)
219+
result = result.combine(value, qc.TILEDB_OR)
169220

170221
return result
171222

tiledb/tests/test_query_condition.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,62 @@ def test_01(self):
548548
qc = tiledb.QueryCondition("a == 1 or (b == 1 and c == 1)")
549549
result = A.query(attr_cond=qc)[:]
550550
assert all(result["a"] | result["b"] & result["c"])
551+
552+
def test_in_operator_sparse(self):
553+
with tiledb.open(self.create_input_array_UIDSA(sparse=True)) as A:
554+
qc = tiledb.QueryCondition("U in [1, 2, 3]")
555+
result = A.query(attr_cond=qc, attrs=["U"])[:]
556+
for val in result["U"]:
557+
assert val in [1, 2, 3]
558+
559+
qc = tiledb.QueryCondition("S in ['a', 'e', 'i', 'o', 'u']")
560+
result = A.query(attr_cond=qc, attrs=["S"])[:]
561+
for val in result["S"]:
562+
assert val in [b"a", b"e", b"i", b"o", b"u"]
563+
564+
qc = tiledb.QueryCondition(
565+
"S in ['a', 'e', 'i', 'o', 'u'] and U in [5, 6, 7]"
566+
)
567+
result = A.query(attr_cond=qc)[:]
568+
for val in result["U"]:
569+
assert val in [5, 6, 7]
570+
for val in result["S"]:
571+
assert val in [b"a", b"e", b"i", b"o", b"u"]
572+
573+
result = A.query(attr_cond=tiledb.QueryCondition("U in [8]"))[:]
574+
for val in result["U"]:
575+
assert val == 8
576+
577+
result = A.query(attr_cond=tiledb.QueryCondition("S in ['8']"))[:]
578+
assert len(result["S"]) == 0
579+
580+
def test_in_operator_dense(self):
581+
with tiledb.open(self.create_input_array_UIDSA(sparse=False)) as A:
582+
U_mask = A.attr("U").fill
583+
S_mask = A.attr("S").fill
584+
585+
qc = tiledb.QueryCondition("U in [1, 2, 3]")
586+
result = A.query(attr_cond=qc, attrs=["U"])[:]
587+
for val in self.filter_sparse(result["U"], U_mask):
588+
assert val in [1, 2, 3]
589+
590+
qc = tiledb.QueryCondition("S in ['a', 'e', 'i', 'o', 'u']")
591+
result = A.query(attr_cond=qc, attrs=["S"])[:]
592+
for val in self.filter_sparse(result["S"], S_mask):
593+
assert val in [b"a", b"e", b"i", b"o", b"u"]
594+
595+
qc = tiledb.QueryCondition(
596+
"S in ['a', 'e', 'i', 'o', 'u'] and U in [5, 6, 7]"
597+
)
598+
result = A.query(attr_cond=qc)[:]
599+
for val in self.filter_sparse(result["U"], U_mask):
600+
assert val in [5, 6, 7]
601+
for val in self.filter_sparse(result["S"], S_mask):
602+
assert val in [b"a", b"e", b"i", b"o", b"u"]
603+
604+
result = A.query(attr_cond=tiledb.QueryCondition("U in [8]"))[:]
605+
for val in self.filter_sparse(result["U"], U_mask):
606+
assert val == 8
607+
608+
result = A.query(attr_cond=tiledb.QueryCondition("S in ['8']"))[:]
609+
assert len(self.filter_sparse(result["S"], S_mask)) == 0

0 commit comments

Comments
 (0)