Skip to content

Commit ed9a520

Browse files
Merge pull request #1178 from datajoint/dj-top-1084-continued
dj.Top continued (#1084)
2 parents 0a49595 + 10f2b9f commit ed9a520

18 files changed

+698
-145
lines changed

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"editor.formatOnPaste": false,
3-
"editor.formatOnSave": true,
3+
"editor.formatOnSave": false,
44
"editor.rulers": [
55
94
66
],

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## Release notes
22

33
### 0.14.3 -- TBD
4+
- Added - `dj.Top` restriction ([#1024](https://github.com/datajoint/datajoint-python/issues/1024)) PR [#1084](https://github.com/datajoint/datajoint-python/pull/1084)
45
- Fixed - Added encapsulating double quotes to comply with [DOT language](https://graphviz.org/doc/info/lang.html) - PR [#1177](https://github.com/datajoint/datajoint-python/pull/1177)
56
- Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091)
67

datajoint/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"Part",
3838
"Not",
3939
"AndList",
40+
"Top",
4041
"U",
4142
"Diagram",
4243
"Di",
@@ -61,7 +62,7 @@
6162
from .schemas import VirtualModule, list_schemas
6263
from .table import Table, FreeTable
6364
from .user_tables import Manual, Lookup, Imported, Computed, Part
64-
from .expression import Not, AndList, U
65+
from .expression import Not, AndList, U, Top
6566
from .diagram import Diagram
6667
from .admin import set_password, kill
6768
from .blob import MatCell, MatStruct

datajoint/condition.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import pandas
1111
import json
1212
from .errors import DataJointError
13+
from typing import Union, List
14+
from dataclasses import dataclass
1315

1416
JSON_PATTERN = re.compile(
1517
r"^(?P<attr>\w+)(\.(?P<path>[\w.*\[\]]+))?(:(?P<type>[\w(,\s)]+))?$"
@@ -61,6 +63,35 @@ def append(self, restriction):
6163
super().append(restriction)
6264

6365

66+
@dataclass
67+
class Top:
68+
"""
69+
A restriction to the top entities of a query.
70+
In SQL, this corresponds to ORDER BY ... LIMIT ... OFFSET
71+
"""
72+
73+
limit: Union[int, None] = 1
74+
order_by: Union[str, List[str]] = "KEY"
75+
offset: int = 0
76+
77+
def __post_init__(self):
78+
self.order_by = self.order_by or ["KEY"]
79+
self.offset = self.offset or 0
80+
81+
if self.limit is not None and not isinstance(self.limit, int):
82+
raise TypeError("Top limit must be an integer")
83+
if not isinstance(self.order_by, (str, collections.abc.Sequence)) or not all(
84+
isinstance(r, str) for r in self.order_by
85+
):
86+
raise TypeError("Top order_by attributes must all be strings")
87+
if not isinstance(self.offset, int):
88+
raise TypeError("The offset argument must be an integer")
89+
if self.offset and self.limit is None:
90+
self.limit = 999999999999 # arbitrary large number to allow query
91+
if isinstance(self.order_by, str):
92+
self.order_by = [self.order_by]
93+
94+
6495
class Not:
6596
"""invert restriction"""
6697

datajoint/declare.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,11 @@ def format_attribute(attr):
455455
return f"`{attr}`"
456456
return f"({attr})"
457457

458-
match = re.match(
459-
r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I
460-
).groupdict()
458+
match = re.match(r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I)
459+
if match is None:
460+
raise DataJointError(f'Table definition syntax error in line "{line}"')
461+
match = match.groupdict()
462+
461463
attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"])
462464
index_sql.append(
463465
"{unique}index ({attrs})".format(

datajoint/expression.py

Lines changed: 82 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .preview import preview, repr_html
1010
from .condition import (
1111
AndList,
12+
Top,
1213
Not,
1314
make_condition,
1415
assert_join_compatibility,
@@ -52,6 +53,7 @@ class QueryExpression:
5253
_connection = None
5354
_heading = None
5455
_support = None
56+
_top = None
5557

5658
# If the query will be using distinct
5759
_distinct = False
@@ -121,17 +123,33 @@ def where_clause(self):
121123
else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction)
122124
)
123125

126+
def sorting_clauses(self):
127+
if not self._top:
128+
return ""
129+
clause = ", ".join(
130+
_wrap_attributes(
131+
_flatten_attribute_list(self.primary_key, self._top.order_by)
132+
)
133+
)
134+
if clause:
135+
clause = f" ORDER BY {clause}"
136+
if self._top.limit is not None:
137+
clause += f" LIMIT {self._top.limit}{f' OFFSET {self._top.offset}' if self._top.offset else ''}"
138+
139+
return clause
140+
124141
def make_sql(self, fields=None):
125142
"""
126143
Make the SQL SELECT statement.
127144
128145
:param fields: used to explicitly set the select attributes
129146
"""
130-
return "SELECT {distinct}{fields} FROM {from_}{where}".format(
147+
return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format(
131148
distinct="DISTINCT " if self._distinct else "",
132149
fields=self.heading.as_sql(fields or self.heading.names),
133150
from_=self.from_clause(),
134151
where=self.where_clause(),
152+
sorting=self.sorting_clauses(),
135153
)
136154

137155
# --------- query operators -----------
@@ -189,6 +207,14 @@ def restrict(self, restriction):
189207
string, or an AndList.
190208
"""
191209
attributes = set()
210+
if isinstance(restriction, Top):
211+
result = (
212+
self.make_subquery()
213+
if self._top and not self._top.__eq__(restriction)
214+
else copy.copy(self)
215+
) # make subquery to avoid overwriting existing Top
216+
result._top = restriction
217+
return result
192218
new_condition = make_condition(self, restriction, attributes)
193219
if new_condition is True:
194220
return self # restriction has no effect, return the same object
@@ -202,8 +228,10 @@ def restrict(self, restriction):
202228
pass # all ok
203229
# If the new condition uses any new attributes, a subquery is required.
204230
# However, Aggregation's HAVING statement works fine with aliased attributes.
205-
need_subquery = isinstance(self, Union) or (
206-
not isinstance(self, Aggregation) and self.heading.new_attributes
231+
need_subquery = (
232+
isinstance(self, Union)
233+
or (not isinstance(self, Aggregation) and self.heading.new_attributes)
234+
or self._top
207235
)
208236
if need_subquery:
209237
result = self.make_subquery()
@@ -539,19 +567,20 @@ def tail(self, limit=25, **fetch_kwargs):
539567

540568
def __len__(self):
541569
""":return: number of elements in the result set e.g. ``len(q1)``."""
542-
return self.connection.query(
570+
result = self.make_subquery() if self._top else copy.copy(self)
571+
return result.connection.query(
543572
"SELECT {select_} FROM {from_}{where}".format(
544573
select_=(
545574
"count(*)"
546-
if any(self._left)
575+
if any(result._left)
547576
else "count(DISTINCT {fields})".format(
548-
fields=self.heading.as_sql(
549-
self.primary_key, include_aliases=False
577+
fields=result.heading.as_sql(
578+
result.primary_key, include_aliases=False
550579
)
551580
)
552581
),
553-
from_=self.from_clause(),
554-
where=self.where_clause(),
582+
from_=result.from_clause(),
583+
where=result.where_clause(),
555584
)
556585
).fetchone()[0]
557586

@@ -619,18 +648,12 @@ def __next__(self):
619648
# -- move on to next entry.
620649
return next(self)
621650

622-
def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
651+
def cursor(self, as_dict=False):
623652
"""
624653
See expression.fetch() for input description.
625654
:return: query cursor
626655
"""
627-
if offset and limit is None:
628-
raise DataJointError("limit is required when offset is set")
629656
sql = self.make_sql()
630-
if order_by is not None:
631-
sql += " ORDER BY " + ", ".join(order_by)
632-
if limit is not None:
633-
sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "")
634657
logger.debug(sql)
635658
return self.connection.query(sql, as_dict=as_dict)
636659

@@ -701,23 +724,26 @@ def make_sql(self, fields=None):
701724
fields = self.heading.as_sql(fields or self.heading.names)
702725
assert self._grouping_attributes or not self.restriction
703726
distinct = set(self.heading.names) == set(self.primary_key)
704-
return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format(
705-
distinct="DISTINCT " if distinct else "",
706-
fields=fields,
707-
from_=self.from_clause(),
708-
where=self.where_clause(),
709-
group_by=(
710-
""
711-
if not self.primary_key
712-
else (
713-
" GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
714-
+ (
715-
""
716-
if not self.restriction
717-
else " HAVING (%s)" % ")AND(".join(self.restriction)
727+
return (
728+
"SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format(
729+
distinct="DISTINCT " if distinct else "",
730+
fields=fields,
731+
from_=self.from_clause(),
732+
where=self.where_clause(),
733+
group_by=(
734+
""
735+
if not self.primary_key
736+
else (
737+
" GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
738+
+ (
739+
""
740+
if not self.restriction
741+
else " HAVING (%s)" % ")AND(".join(self.restriction)
742+
)
718743
)
719-
)
720-
),
744+
),
745+
sorting=self.sorting_clauses(),
746+
)
721747
)
722748

723749
def __len__(self):
@@ -776,7 +802,7 @@ def make_sql(self):
776802
):
777803
# no secondary attributes: use UNION DISTINCT
778804
fields = arg1.primary_key
779-
return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format(
805+
return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting}`".format(
780806
sql1=(
781807
arg1.make_sql()
782808
if isinstance(arg1, Union)
@@ -788,6 +814,7 @@ def make_sql(self):
788814
else arg2.make_sql(fields)
789815
),
790816
alias=next(self.__count),
817+
sorting=self.sorting_clauses(),
791818
)
792819
# with secondary attributes, use union of left join with antijoin
793820
fields = self.heading.names
@@ -939,3 +966,25 @@ def aggr(self, group, **named_attributes):
939966
)
940967

941968
aggregate = aggr # alias for aggr
969+
970+
971+
def _flatten_attribute_list(primary_key, attrs):
972+
"""
973+
:param primary_key: list of attributes in primary key
974+
:param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
975+
:return: generator of attributes where "KEY" is replaced with its component attributes
976+
"""
977+
for a in attrs:
978+
if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a):
979+
if primary_key:
980+
yield from primary_key
981+
elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a):
982+
if primary_key:
983+
yield from (q + " DESC" for q in primary_key)
984+
else:
985+
yield a
986+
987+
988+
def _wrap_attributes(attr):
989+
for entry in attr: # wrap attribute names in backquotes
990+
yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE)

datajoint/fetch.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
from functools import partial
22
from pathlib import Path
3-
import logging
43
import pandas
54
import itertools
6-
import re
75
import json
86
import numpy as np
97
import uuid
108
import numbers
9+
10+
from datajoint.condition import Top
1111
from . import blob, hash
1212
from .errors import DataJointError
1313
from .settings import config
1414
from .utils import safe_write
1515

16-
logger = logging.getLogger(__name__.split(".")[0])
17-
1816

1917
class key:
2018
"""
@@ -119,21 +117,6 @@ def _get(connection, attr, data, squeeze, download_path):
119117
)
120118

121119

122-
def _flatten_attribute_list(primary_key, attrs):
123-
"""
124-
:param primary_key: list of attributes in primary key
125-
:param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
126-
:return: generator of attributes where "KEY" is replaces with its component attributes
127-
"""
128-
for a in attrs:
129-
if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a):
130-
yield from primary_key
131-
elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a):
132-
yield from (q + " DESC" for q in primary_key)
133-
else:
134-
yield a
135-
136-
137120
class Fetch:
138121
"""
139122
A fetch object that handles retrieving elements from the table expression.
@@ -174,13 +157,13 @@ def __call__(
174157
:param download_path: for fetches that download data, e.g. attachments
175158
:return: the contents of the table in the form of a structured numpy.array or a dict list
176159
"""
177-
if order_by is not None:
178-
# if 'order_by' passed in a string, make into list
179-
if isinstance(order_by, str):
180-
order_by = [order_by]
181-
# expand "KEY" or "KEY DESC"
182-
order_by = list(
183-
_flatten_attribute_list(self._expression.primary_key, order_by)
160+
if offset or order_by or limit:
161+
self._expression = self._expression.restrict(
162+
Top(
163+
limit,
164+
order_by,
165+
offset,
166+
)
184167
)
185168

186169
attrs_as_dict = as_dict and attrs
@@ -212,13 +195,6 @@ def __call__(
212195
'use "array" or "frame"'.format(format)
213196
)
214197

215-
if limit is None and offset is not None:
216-
logger.warning(
217-
"Offset set, but no limit. Setting limit to a large number. "
218-
"Consider setting a limit explicitly."
219-
)
220-
limit = 8000000000 # just a very large number to effect no limit
221-
222198
get = partial(
223199
_get,
224200
self._expression.connection,
@@ -257,9 +233,7 @@ def __call__(
257233
]
258234
ret = return_values[0] if len(attrs) == 1 else return_values
259235
else: # fetch all attributes as a numpy.record_array or pandas.DataFrame
260-
cur = self._expression.cursor(
261-
as_dict=as_dict, limit=limit, offset=offset, order_by=order_by
262-
)
236+
cur = self._expression.cursor(as_dict=as_dict)
263237
heading = self._expression.heading
264238
if as_dict:
265239
ret = [

0 commit comments

Comments
 (0)