9
9
from .preview import preview , repr_html
10
10
from .condition import (
11
11
AndList ,
12
+ Top ,
12
13
Not ,
13
14
make_condition ,
14
15
assert_join_compatibility ,
@@ -52,6 +53,7 @@ class QueryExpression:
52
53
_connection = None
53
54
_heading = None
54
55
_support = None
56
+ _top = None
55
57
56
58
# If the query will be using distinct
57
59
_distinct = False
@@ -121,17 +123,33 @@ def where_clause(self):
121
123
else " WHERE (%s)" % ")AND(" .join (str (s ) for s in self .restriction )
122
124
)
123
125
126
+ def sorting_clauses (self ):
127
+ if not self ._top :
128
+ return ""
129
+ clause = ", " .join (
130
+ _wrap_attributes (
131
+ _flatten_attribute_list (self .primary_key , self ._top .order_by )
132
+ )
133
+ )
134
+ if clause :
135
+ clause = f" ORDER BY { clause } "
136
+ if self ._top .limit is not None :
137
+ clause += f" LIMIT { self ._top .limit } { f' OFFSET { self ._top .offset } ' if self ._top .offset else '' } "
138
+
139
+ return clause
140
+
124
141
def make_sql (self , fields = None ):
125
142
"""
126
143
Make the SQL SELECT statement.
127
144
128
145
:param fields: used to explicitly set the select attributes
129
146
"""
130
- return "SELECT {distinct}{fields} FROM {from_}{where}" .format (
147
+ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting} " .format (
131
148
distinct = "DISTINCT " if self ._distinct else "" ,
132
149
fields = self .heading .as_sql (fields or self .heading .names ),
133
150
from_ = self .from_clause (),
134
151
where = self .where_clause (),
152
+ sorting = self .sorting_clauses (),
135
153
)
136
154
137
155
# --------- query operators -----------
@@ -189,6 +207,14 @@ def restrict(self, restriction):
189
207
string, or an AndList.
190
208
"""
191
209
attributes = set ()
210
+ if isinstance (restriction , Top ):
211
+ result = (
212
+ self .make_subquery ()
213
+ if self ._top and not self ._top .__eq__ (restriction )
214
+ else copy .copy (self )
215
+ ) # make subquery to avoid overwriting existing Top
216
+ result ._top = restriction
217
+ return result
192
218
new_condition = make_condition (self , restriction , attributes )
193
219
if new_condition is True :
194
220
return self # restriction has no effect, return the same object
@@ -202,8 +228,10 @@ def restrict(self, restriction):
202
228
pass # all ok
203
229
# If the new condition uses any new attributes, a subquery is required.
204
230
# However, Aggregation's HAVING statement works fine with aliased attributes.
205
- need_subquery = isinstance (self , Union ) or (
206
- not isinstance (self , Aggregation ) and self .heading .new_attributes
231
+ need_subquery = (
232
+ isinstance (self , Union )
233
+ or (not isinstance (self , Aggregation ) and self .heading .new_attributes )
234
+ or self ._top
207
235
)
208
236
if need_subquery :
209
237
result = self .make_subquery ()
@@ -539,19 +567,20 @@ def tail(self, limit=25, **fetch_kwargs):
539
567
540
568
def __len__ (self ):
541
569
""":return: number of elements in the result set e.g. ``len(q1)``."""
542
- return self .connection .query (
570
+ result = self .make_subquery () if self ._top else copy .copy (self )
571
+ return result .connection .query (
543
572
"SELECT {select_} FROM {from_}{where}" .format (
544
573
select_ = (
545
574
"count(*)"
546
- if any (self ._left )
575
+ if any (result ._left )
547
576
else "count(DISTINCT {fields})" .format (
548
- fields = self .heading .as_sql (
549
- self .primary_key , include_aliases = False
577
+ fields = result .heading .as_sql (
578
+ result .primary_key , include_aliases = False
550
579
)
551
580
)
552
581
),
553
- from_ = self .from_clause (),
554
- where = self .where_clause (),
582
+ from_ = result .from_clause (),
583
+ where = result .where_clause (),
555
584
)
556
585
).fetchone ()[0 ]
557
586
@@ -619,18 +648,12 @@ def __next__(self):
619
648
# -- move on to next entry.
620
649
return next (self )
621
650
622
- def cursor (self , offset = 0 , limit = None , order_by = None , as_dict = False ):
651
+ def cursor (self , as_dict = False ):
623
652
"""
624
653
See expression.fetch() for input description.
625
654
:return: query cursor
626
655
"""
627
- if offset and limit is None :
628
- raise DataJointError ("limit is required when offset is set" )
629
656
sql = self .make_sql ()
630
- if order_by is not None :
631
- sql += " ORDER BY " + ", " .join (order_by )
632
- if limit is not None :
633
- sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "" )
634
657
logger .debug (sql )
635
658
return self .connection .query (sql , as_dict = as_dict )
636
659
@@ -701,23 +724,26 @@ def make_sql(self, fields=None):
701
724
fields = self .heading .as_sql (fields or self .heading .names )
702
725
assert self ._grouping_attributes or not self .restriction
703
726
distinct = set (self .heading .names ) == set (self .primary_key )
704
- return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}" .format (
705
- distinct = "DISTINCT " if distinct else "" ,
706
- fields = fields ,
707
- from_ = self .from_clause (),
708
- where = self .where_clause (),
709
- group_by = (
710
- ""
711
- if not self .primary_key
712
- else (
713
- " GROUP BY `%s`" % "`,`" .join (self ._grouping_attributes )
714
- + (
715
- ""
716
- if not self .restriction
717
- else " HAVING (%s)" % ")AND(" .join (self .restriction )
727
+ return (
728
+ "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}" .format (
729
+ distinct = "DISTINCT " if distinct else "" ,
730
+ fields = fields ,
731
+ from_ = self .from_clause (),
732
+ where = self .where_clause (),
733
+ group_by = (
734
+ ""
735
+ if not self .primary_key
736
+ else (
737
+ " GROUP BY `%s`" % "`,`" .join (self ._grouping_attributes )
738
+ + (
739
+ ""
740
+ if not self .restriction
741
+ else " HAVING (%s)" % ")AND(" .join (self .restriction )
742
+ )
718
743
)
719
- )
720
- ),
744
+ ),
745
+ sorting = self .sorting_clauses (),
746
+ )
721
747
)
722
748
723
749
def __len__ (self ):
@@ -776,7 +802,7 @@ def make_sql(self):
776
802
):
777
803
# no secondary attributes: use UNION DISTINCT
778
804
fields = arg1 .primary_key
779
- return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`" .format (
805
+ return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting} `" .format (
780
806
sql1 = (
781
807
arg1 .make_sql ()
782
808
if isinstance (arg1 , Union )
@@ -788,6 +814,7 @@ def make_sql(self):
788
814
else arg2 .make_sql (fields )
789
815
),
790
816
alias = next (self .__count ),
817
+ sorting = self .sorting_clauses (),
791
818
)
792
819
# with secondary attributes, use union of left join with antijoin
793
820
fields = self .heading .names
@@ -939,3 +966,25 @@ def aggr(self, group, **named_attributes):
939
966
)
940
967
941
968
aggregate = aggr # alias for aggr
969
+
970
+
971
+ def _flatten_attribute_list (primary_key , attrs ):
972
+ """
973
+ :param primary_key: list of attributes in primary key
974
+ :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
975
+ :return: generator of attributes where "KEY" is replaced with its component attributes
976
+ """
977
+ for a in attrs :
978
+ if re .match (r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$" , a ):
979
+ if primary_key :
980
+ yield from primary_key
981
+ elif re .match (r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$" , a ):
982
+ if primary_key :
983
+ yield from (q + " DESC" for q in primary_key )
984
+ else :
985
+ yield a
986
+
987
+
988
+ def _wrap_attributes (attr ):
989
+ for entry in attr : # wrap attribute names in backquotes
990
+ yield re .sub (r"\b((?!asc|desc)\w+)\b" , r"`\1`" , entry , flags = re .IGNORECASE )
0 commit comments