Skip to content

Commit 71e553a

Browse files
fix: Aggregation and Union __bool__ always returning True (#1234)
The __bool__ methods were returning bool(cursor) instead of bool(cursor.fetchone()[0]), causing them to always return True regardless of whether the query result was empty. Added tests for Aggregation and Union __bool__ behavior: - test_aggregation_bool_with_results - test_aggregation_bool_empty - test_aggregation_bool_matches_len - test_union_bool_with_results - test_union_bool_empty - test_union_bool_matches_len Fixes #1234 Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent c1b36f0 commit 71e553a

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

src/datajoint/expression.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False):
728728
try:
729729
import polars
730730
except ImportError:
731-
raise ImportError("polars is required for to_polars(). " "Install with: pip install datajoint[polars]")
731+
raise ImportError("polars is required for to_polars(). Install with: pip install datajoint[polars]")
732732
dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
733733
return polars.DataFrame(dicts)
734734

@@ -747,7 +747,7 @@ def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False):
747747
try:
748748
import pyarrow
749749
except ImportError:
750-
raise ImportError("pyarrow is required for to_arrow(). " "Install with: pip install datajoint[arrow]")
750+
raise ImportError("pyarrow is required for to_arrow(). Install with: pip install datajoint[arrow]")
751751
dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
752752
if not dicts:
753753
return pyarrow.table({})
@@ -1039,7 +1039,7 @@ def __len__(self):
10391039
).fetchone()[0]
10401040

10411041
def __bool__(self):
1042-
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())))
1042+
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0])
10431043

10441044

10451045
class Union(QueryExpression):
@@ -1101,7 +1101,7 @@ def __len__(self):
11011101
).fetchone()[0]
11021102

11031103
def __bool__(self):
1104-
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())))
1104+
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0])
11051105

11061106

11071107
class U:

tests/integration/test_aggr_regressions.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,70 @@ def test_extend_invalid_raises_error(schema_uuid):
180180
with pytest.raises(DataJointError) as exc_info:
181181
Topic.extend(Item)
182182
assert "left operand to determine" in str(exc_info.value).lower()
183+
184+
185+
class TestBoolMethod:
186+
"""
187+
Tests for __bool__ method on Aggregation and Union (issue #1234).
188+
189+
bool(query) should return True if query has rows, False if empty.
190+
"""
191+
192+
def test_aggregation_bool_with_results(self, schema_aggr_reg_with_abx):
193+
"""Aggregation with results should be truthy."""
194+
A.insert([(1,), (2,), (3,)])
195+
B.insert([(1, 10), (1, 20), (2, 30)])
196+
aggr = A.aggr(B, count="count(id2)")
197+
assert bool(aggr) is True
198+
assert len(aggr) > 0
199+
200+
def test_aggregation_bool_empty(self, schema_aggr_reg_with_abx):
201+
"""Aggregation with no results should be falsy."""
202+
A.insert([(1,), (2,), (3,)])
203+
B.insert([(1, 10), (1, 20), (2, 30)])
204+
# Restrict to non-existent entry
205+
aggr = (A & "id=999").aggr(B, count="count(id2)")
206+
assert bool(aggr) is False
207+
assert len(aggr) == 0
208+
209+
def test_aggregation_bool_matches_len(self, schema_aggr_reg_with_abx):
210+
"""bool(aggr) should equal len(aggr) > 0."""
211+
A.insert([(10,), (20,)])
212+
B.insert([(10, 100)])
213+
# With results
214+
aggr_has = A.aggr(B, count="count(id2)")
215+
assert bool(aggr_has) == (len(aggr_has) > 0)
216+
# Without results
217+
aggr_empty = (A & "id=999").aggr(B, count="count(id2)")
218+
assert bool(aggr_empty) == (len(aggr_empty) > 0)
219+
220+
def test_union_bool_with_results(self, schema_aggr_reg_with_abx):
221+
"""Union with results should be truthy."""
222+
A.insert([(100,), (200,)])
223+
B.insert([(100, 1), (200, 2)])
224+
q1 = B & "id=100"
225+
q2 = B & "id=200"
226+
union = q1 + q2
227+
assert bool(union) is True
228+
assert len(union) > 0
229+
230+
def test_union_bool_empty(self, schema_aggr_reg_with_abx):
231+
"""Union with no results should be falsy."""
232+
A.insert([(100,), (200,)])
233+
B.insert([(100, 1), (200, 2)])
234+
q1 = B & "id=999"
235+
q2 = B & "id=998"
236+
union = q1 + q2
237+
assert bool(union) is False
238+
assert len(union) == 0
239+
240+
def test_union_bool_matches_len(self, schema_aggr_reg_with_abx):
241+
"""bool(union) should equal len(union) > 0."""
242+
A.insert([(100,), (200,)])
243+
B.insert([(100, 1)])
244+
# With results
245+
union_has = (B & "id=100") + (B & "id=100")
246+
assert bool(union_has) == (len(union_has) > 0)
247+
# Without results
248+
union_empty = (B & "id=999") + (B & "id=998")
249+
assert bool(union_empty) == (len(union_empty) > 0)

0 commit comments

Comments
 (0)