Skip to content

Commit d5c275d

Browse files
Merge remote-tracking branch 'origin/claude/modern-fetch-api' into docs-2.0-migration
2 parents 7820702 + d1dafdc commit d5c275d

File tree

4 files changed

+42
-20
lines changed

4 files changed

+42
-20
lines changed

src/datajoint/expression.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def join(self, other, semantic_check=True, left=False, allow_nullable_pk=False):
306306
:param allow_nullable_pk: If True, bypass the left join constraint that requires
307307
self to determine other. When bypassed, the result PK is the union of both
308308
operands' PKs, and PK attributes from the right operand could be NULL.
309-
Used internally by aggregation with keep_all_rows=True.
309+
Used internally by aggregation when exclude_nonmatching=False.
310310
:return: The joined QueryExpression
311311
312312
a * b is short for a.join(b)
@@ -538,21 +538,33 @@ def proj(self, *attributes, **named_attributes):
538538
)
539539
return result
540540

541-
def aggr(self, group, *attributes, keep_all_rows=False, **named_attributes):
541+
def aggr(self, group, *attributes, exclude_nonmatching=False, **named_attributes):
542542
"""
543-
Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
544-
has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.
543+
Aggregation/grouping operation, similar to proj but with computations over a grouped relation.
545544
546-
:param group: The query expression to be aggregated.
547-
:param keep_all_rows: True=keep all the rows from self. False=keep only rows that match entries in group.
545+
By default, keeps all rows from self (like proj). Use exclude_nonmatching=True to
546+
keep only rows that have matches in group.
547+
548+
:param group: The query expression to be aggregated.
549+
:param exclude_nonmatching: If True, exclude rows from self that have no matching
550+
entries in group (INNER JOIN). Default False keeps all rows (LEFT JOIN).
548551
:param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
549552
:return: The derived query expression
553+
554+
Example::
555+
556+
# Count sessions per subject (keeps all subjects, even those with 0 sessions)
557+
Subject.aggr(Session, n="count(*)")
558+
559+
# Count sessions per subject (only subjects with at least one session)
560+
Subject.aggr(Session, n="count(*)", exclude_nonmatching=True)
550561
"""
551562
if Ellipsis in attributes:
552563
# expand ellipsis to include only attributes from the left table
553564
attributes = set(attributes)
554565
attributes.discard(Ellipsis)
555566
attributes.update(self.heading.secondary_attributes)
567+
keep_all_rows = not exclude_nonmatching
556568
return Aggregation.create(self, group=group, keep_all_rows=keep_all_rows).proj(*attributes, **named_attributes)
557569

558570
aggregate = aggr # alias for aggr
@@ -1170,12 +1182,14 @@ def aggr(self, group, **named_attributes):
11701182
Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
11711183
has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.
11721184
1185+
Note: exclude_nonmatching is always True for dj.U (cannot keep all rows from an infinite set).
1186+
11731187
:param group: The query expression to be aggregated.
11741188
:param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
11751189
:return: The derived query expression
11761190
"""
1177-
if named_attributes.get("keep_all_rows", False):
1178-
raise DataJointError("Cannot set keep_all_rows=True when aggregating on a universal set.")
1191+
if named_attributes.pop("exclude_nonmatching", True) is False:
1192+
raise DataJointError("Cannot set exclude_nonmatching=False when aggregating on a universal set.")
11791193

11801194
if inspect.isclass(group) and issubclass(group, QueryExpression):
11811195
group = group()

src/datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# version bump auto managed by Github Actions:
22
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
33
# manually set this version will be eventually overwritten by the above actions
4-
__version__ = "2.0.0a15"
4+
__version__ = "2.0.0a16"

tests/integration/test_relational_operand.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,12 @@ def test_heading_repr(schema_simp_pop):
230230

231231

232232
def test_aggregate(schema_simp_pop):
233-
x = B().aggregate(B.C())
233+
# With exclude_nonmatching=True, only rows with matches are kept (INNER JOIN)
234+
x = B().aggregate(B.C(), exclude_nonmatching=True)
234235
assert len(x) == len(B() & B.C())
235236

236-
x = B().aggregate(B.C(), keep_all_rows=True)
237+
# Default behavior now keeps all rows (LEFT JOIN)
238+
x = B().aggregate(B.C())
237239
assert len(x) == len(B()) # test LEFT join
238240

239241
assert len((x & "id_b=0").to_arrays()) == len(B() & "id_b=0") # test restricted aggregation
@@ -244,7 +246,6 @@ def test_aggregate(schema_simp_pop):
244246
count="count(id_c)",
245247
mean="avg(value)",
246248
max="max(value)",
247-
keep_all_rows=True,
248249
)
249250
assert len(x) == len(B())
250251
y = x & "mean>0" # restricted aggregation
@@ -260,12 +261,14 @@ def test_aggregate(schema_simp_pop):
260261

261262

262263
def test_aggr(schema_simp_pop):
263-
x = B.aggr(B.C)
264+
# With exclude_nonmatching=True, only rows with matches are kept (INNER JOIN)
265+
x = B.aggr(B.C, exclude_nonmatching=True)
264266
l1 = len(x)
265267
l2 = len(B & B.C)
266268
assert l1 == l2
267269

268-
x = B().aggr(B.C(), keep_all_rows=True)
270+
# Default behavior now keeps all rows (LEFT JOIN)
271+
x = B().aggr(B.C())
269272
assert len(x) == len(B()) # test LEFT join
270273

271274
assert len((x & "id_b=0").to_arrays()) == len(B() & "id_b=0") # test restricted aggregation
@@ -276,7 +279,6 @@ def test_aggr(schema_simp_pop):
276279
count="count(id_c)",
277280
mean="avg(value)",
278281
max="max(value)",
279-
keep_all_rows=True,
280282
)
281283
assert len(x) == len(B())
282284
y = x & "mean>0" # restricted aggregation

tests/integration/test_university.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,23 +138,29 @@ def test_union(schema_uni):
138138

139139

140140
def test_aggr(schema_uni):
141+
# Default: keeps all courses (some may have NULL avg_grade if no grades)
141142
avg_grade_per_course = Course.aggr(Grade * LetterGrade, avg_grade="round(avg(points), 2)")
142143
assert len(avg_grade_per_course) == 45
143144

144-
# GPA
145-
student_gpa = Student.aggr(Course * Grade * LetterGrade, gpa="round(sum(points*credits)/sum(credits), 2)")
145+
# GPA - use exclude_nonmatching=True to only include students with grades
146+
student_gpa = Student.aggr(
147+
Course * Grade * LetterGrade,
148+
gpa="round(sum(points*credits)/sum(credits), 2)",
149+
exclude_nonmatching=True,
150+
)
146151
gpa = student_gpa.to_arrays("gpa")
147-
assert len(gpa) == 261
152+
assert len(gpa) == 261 # only students with grades
148153
assert 2 < gpa.mean() < 3
149154

150155
# Sections in biology department with zero students in them
151-
section = (Section & {"dept": "BIOL"}).aggr(Enroll, n="count(student_id)", keep_all_rows=True) & "n=0"
156+
# aggr now keeps all rows by default (like proj), so sections with 0 enrollments are included
157+
section = (Section & {"dept": "BIOL"}).aggr(Enroll, n="count(student_id)") & "n=0"
152158
assert len(set(section.to_arrays("dept"))) == 1
153159
assert len(section) == 17
154160
assert bool(section)
155161

156162
# Test correct use of ellipses in a similar query
157-
section = (Section & {"dept": "BIOL"}).aggr(Grade, ..., n="count(student_id)", keep_all_rows=True) & "n>1"
163+
section = (Section & {"dept": "BIOL"}).aggr(Grade, ..., n="count(student_id)") & "n>1"
158164
assert not any(name in section.heading.names for name in Grade.heading.secondary_attributes)
159165
assert len(set(section.to_arrays("dept"))) == 1
160166
assert len(section) == 168

0 commit comments

Comments
 (0)