Skip to content

Commit d8eb1c0

Browse files
Merge pull request #893 from guzman-raphael/complex_date_restriction
Fix complex date restriction regression and various others
2 parents 909749a + 80db98d commit d8eb1c0

File tree

16 files changed

+283
-71
lines changed

16 files changed

+283
-71
lines changed

.github/workflows/development.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
strategy:
1616
matrix:
1717
py_ver: ["3.8"]
18-
mysql_ver: ["8.0", "5.7", "5.6"]
18+
mysql_ver: ["8.0", "5.7"]
1919
include:
2020
- py_ver: "3.7"
2121
mysql_ver: "5.7"

CHANGELOG.md

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
11
## Release notes
22

3+
### 0.13.1 -- TBD
4+
* Add `None` as an alias for `IS NULL` comparison in `dict` restrictions (#824) PR #893
5+
* Drop support for MySQL 5.6 since it has reached EOL PR #893
6+
* Bugfix - `schema.list_tables()` is not topologically sorted (#838) PR #893
7+
* Bugfix - Diagram part tables do not show proper class name (#882) PR #893
8+
* Bugfix - Error in complex restrictions (#892) PR #893
9+
* Bugfix - WHERE and GROUP BY clases are dropped on joins with aggregation (#898, #899) PR #893
10+
311
### 0.13.0 -- Mar 24, 2021
4-
* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484). PR #754
5-
* Re-implement cascading deletes for better performance. PR #839.
6-
* Add table method `.update1` to update a row in the table with new values PR #763
7-
* Python datatypes are now enabled by default in blobs (#761). PR #785
12+
* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484, #558). PR #754
13+
* Re-implement cascading deletes for better performance. PR #839
14+
* Add support for deferred schema activation to allow for greater modularity. (#834) PR #839
15+
* Add query caching mechanism for offline development (#550) PR #839
16+
* Add table method `.update1` to update a row in the table with new values (#867) PR #763, #889
17+
* Python datatypes are now enabled by default in blobs (#761). PR #859
818
* Added permissive join and restriction operators `@` and `^` (#785) PR #754
919
* Support DataJoint datatype and connection plugins (#715, #729) PR 730, #735
10-
* Add `dj.key_hash` alias to `dj.hash.key_hash`
20+
* Add `dj.key_hash` alias to `dj.hash.key_hash` (#804) PR #862
1121
* Default enable_python_native_blobs to True
1222
* Bugfix - Regression error on joins with same attribute name (#857) PR #878
1323
* Bugfix - Error when `fetch1('KEY')` when `dj.config['fetch_format']='frame'` set (#876) PR #880, #878
1424
* Bugfix - Error when cascading deletes in tables with many, complex keys (#883, #886) PR #839
1525
* Add deprecation warning for `_update`. PR #889
1626
* Add `purge_query_cache` utility. PR #889
1727
* Add tests for query caching and permissive join and restriction. PR #889
18-
* Drop support for Python 3.5
28+
* Drop support for Python 3.5 (#829) PR #861
1929

2030
### 0.12.9 -- Mar 12, 2021
2131
* Fix bug with fetch1 with `dj.config['fetch_format']="frame"`. (#876) PR #880

datajoint/condition.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ def __init__(self, operand):
2121

2222
class AndList(list):
2323
"""
24-
A list of conditions to by applied to a query expression by logical conjunction: the conditions are AND-ed.
25-
All other collections (lists, sets, other entity sets, etc) are applied by logical disjunction (OR).
24+
A list of conditions to by applied to a query expression by logical conjunction: the
25+
conditions are AND-ed. All other collections (lists, sets, other entity sets, etc) are
26+
applied by logical disjunction (OR).
2627
2728
Example:
2829
expr2 = expr & dj.AndList((cond1, cond2, cond3))
@@ -49,14 +50,16 @@ def assert_join_compatibility(expr1, expr2):
4950
the matching attributes in the two expressions must be in the primary key of one or the
5051
other expression.
5152
Raises an exception if not compatible.
53+
5254
:param expr1: A QueryExpression object
5355
:param expr2: A QueryExpression object
5456
"""
5557
from .expression import QueryExpression, U
5658

5759
for rel in (expr1, expr2):
5860
if not isinstance(rel, (U, QueryExpression)):
59-
raise DataJointError('Object %r is not a QueryExpression and cannot be joined.' % rel)
61+
raise DataJointError(
62+
'Object %r is not a QueryExpression and cannot be joined.' % rel)
6063
if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible
6164
try:
6265
raise DataJointError(
@@ -70,9 +73,11 @@ def assert_join_compatibility(expr1, expr2):
7073
def make_condition(query_expression, condition, columns):
7174
"""
7275
Translate the input condition into the equivalent SQL condition (a string)
76+
7377
:param query_expression: a dj.QueryExpression object to apply condition
7478
:param condition: any valid restriction object.
75-
:param columns: a set passed by reference to collect all column names used in the condition.
79+
:param columns: a set passed by reference to collect all column names used in the
80+
condition.
7681
:return: an SQL condition string or a boolean value.
7782
"""
7883
from .expression import QueryExpression, Aggregation, U
@@ -102,12 +107,13 @@ def prep_value(k, v):
102107
# restrict by string
103108
if isinstance(condition, str):
104109
columns.update(extract_column_names(condition))
105-
return template % condition.strip().replace("%", "%%") # escape % in strings, see issue #376
110+
return template % condition.strip().replace("%", "%%") # escape %, see issue #376
106111

107112
# restrict by AndList
108113
if isinstance(condition, AndList):
109114
# omit all conditions that evaluate to True
110-
items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition)
115+
items = [item for item in (make_condition(query_expression, cond, columns)
116+
for cond in condition)
111117
if item is not True]
112118
if any(item is False for item in items):
113119
return negate # if any item is False, the whole thing is False
@@ -123,18 +129,21 @@ def prep_value(k, v):
123129
if isinstance(condition, bool):
124130
return negate != condition
125131

126-
# restrict by a mapping such as a dict -- convert to an AndList of string equality conditions
132+
# restrict by a mapping/dict -- convert to an AndList of string equality conditions
127133
if isinstance(condition, collections.abc.Mapping):
128134
common_attributes = set(condition).intersection(query_expression.heading.names)
129135
if not common_attributes:
130136
return not negate # no matching attributes -> evaluates to True
131137
columns.update(common_attributes)
132138
return template % ('(' + ') AND ('.join(
133-
'`%s`=%s' % (k, prep_value(k, condition[k])) for k in common_attributes) + ')')
139+
'`%s`%s' % (k, ' IS NULL' if condition[k] is None
140+
else f'={prep_value(k, condition[k])}')
141+
for k in common_attributes) + ')')
134142

135143
# restrict by a numpy record -- convert to an AndList of string equality conditions
136144
if isinstance(condition, numpy.void):
137-
common_attributes = set(condition.dtype.fields).intersection(query_expression.heading.names)
145+
common_attributes = set(condition.dtype.fields).intersection(
146+
query_expression.heading.names)
138147
if not common_attributes:
139148
return not negate # no matching attributes -> evaluate to True
140149
columns.update(common_attributes)
@@ -154,7 +163,8 @@ def prep_value(k, v):
154163
if isinstance(condition, QueryExpression):
155164
if check_compatibility:
156165
assert_join_compatibility(query_expression, condition)
157-
common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names]
166+
common_attributes = [q for q in condition.heading.names
167+
if q in query_expression.heading.names]
158168
columns.update(common_attributes)
159169
if isinstance(condition, Aggregation):
160170
condition = condition.make_subquery()
@@ -176,15 +186,17 @@ def prep_value(k, v):
176186
except TypeError:
177187
raise DataJointError('Invalid restriction type %r' % condition)
178188
else:
179-
or_list = [item for item in or_list if item is not False] # ignore all False conditions
180-
if any(item is True for item in or_list): # if any item is True, the whole thing is True
189+
or_list = [item for item in or_list if item is not False] # ignore False conditions
190+
if any(item is True for item in or_list): # if any item is True, entirely True
181191
return not negate
182-
return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate # an empty or list is False
192+
return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate
183193

184194

185195
def extract_column_names(sql_expression):
186196
"""
187-
extract all presumed column names from an sql expression such as the WHERE clause, for example.
197+
extract all presumed column names from an sql expression such as the WHERE clause,
198+
for example.
199+
188200
:param sql_expression: a string containing an SQL expression
189201
:return: set of extracted column names
190202
This may be MySQL-specific for now.
@@ -206,5 +218,8 @@ def extract_column_names(sql_expression):
206218
s = re.sub(r"(\b[a-z][a-z_0-9]*)\(", "(", s)
207219
remaining_tokens = set(re.findall(r"\b[a-z][a-z_0-9]*\b", s))
208220
# update result removing reserved words
209-
result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null", "not"})
221+
result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null",
222+
"not", "interval", "second", "minute", "hour", "day",
223+
"month", "week", "year"
224+
})
210225
return result

datajoint/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,15 @@ def connect(self):
203203
self._conn = client.connect(
204204
init_command=self.init_fun,
205205
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
206-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
206+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
207207
charset=config['connection.charset'],
208208
**{k: v for k, v in self.conn_info.items()
209209
if k not in ['ssl_input', 'host_input']})
210210
except client.err.InternalError:
211211
self._conn = client.connect(
212212
init_command=self.init_fun,
213213
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
214-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
214+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
215215
charset=config['connection.charset'],
216216
**{k: v for k, v in self.conn_info.items()
217217
if not(k in ['ssl_input', 'host_input'] or

datajoint/diagram.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,24 +219,32 @@ def _make_graph(self):
219219
"""
220220
Make the self.graph - a graph object ready for drawing
221221
"""
222-
# mark "distinguished" tables, i.e. those that introduce new primary key attributes
222+
# mark "distinguished" tables, i.e. those that introduce new primary key
223+
# attributes
223224
for name in self.nodes_to_show:
224225
foreign_attributes = set(
225-
attr for p in self.in_edges(name, data=True) for attr in p[2]['attr_map'] if p[2]['primary'])
226+
attr for p in self.in_edges(name, data=True)
227+
for attr in p[2]['attr_map'] if p[2]['primary'])
226228
self.nodes[name]['distinguished'] = (
227-
'primary_key' in self.nodes[name] and foreign_attributes < self.nodes[name]['primary_key'])
229+
'primary_key' in self.nodes[name] and
230+
foreign_attributes < self.nodes[name]['primary_key'])
228231
# include aliased nodes that are sandwiched between two displayed nodes
229-
gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
230-
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show))
232+
gaps = set(nx.algorithms.boundary.node_boundary(
233+
self, self.nodes_to_show)).intersection(
234+
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(),
235+
self.nodes_to_show))
231236
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit)
232237
# construct subgraph and rename nodes to class names
233238
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
234-
nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n) for n in graph})
239+
nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n)
240+
for n in graph})
235241
# relabel nodes to class names
236-
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
242+
mapping = {node: lookup_class_name(node, self.context) or node
243+
for node in graph.nodes()}
237244
new_names = [mapping.values()]
238245
if len(new_names) > len(set(new_names)):
239-
raise DataJointError('Some classes have identical names. The Diagram cannot be plotted.')
246+
raise DataJointError(
247+
'Some classes have identical names. The Diagram cannot be plotted.')
240248
nx.relabel_nodes(graph, mapping, copy=False)
241249
return graph
242250

datajoint/expression.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ def restriction_attributes(self):
8484
def primary_key(self):
8585
return self.heading.primary_key
8686

87-
_subquery_alias_count = count() # count for alias names used in from_clause
87+
_subquery_alias_count = count() # count for alias names used in the FROM clause
8888

8989
def from_clause(self):
90-
support = ('(' + src.make_sql() + ') as `_s%x`' % next(
91-
self._subquery_alias_count) if isinstance(src, QueryExpression) else src for src in self.support)
90+
support = ('(' + src.make_sql() + ') as `$%x`' % next(
91+
self._subquery_alias_count) if isinstance(src, QueryExpression)
92+
else src for src in self.support)
9293
clause = next(support)
9394
for s, left in zip(support, self._left):
94-
clause += 'NATURAL{left} JOIN {clause}'.format(
95+
clause += ' NATURAL{left} JOIN {clause}'.format(
9596
left=" LEFT" if left else "",
9697
clause=s)
9798
return clause
@@ -264,8 +265,10 @@ def join(self, other, semantic_check=True, left=False):
264265
(set(self.original_heading.names) & set(other.original_heading.names))
265266
- join_attributes)
266267
# need subquery if any of the join attributes are derived
267-
need_subquery1 = need_subquery1 or any(n in self.heading.new_attributes for n in join_attributes)
268-
need_subquery2 = need_subquery2 or any(n in other.heading.new_attributes for n in join_attributes)
268+
need_subquery1 = (need_subquery1 or isinstance(self, Aggregation) or
269+
any(n in self.heading.new_attributes for n in join_attributes))
270+
need_subquery2 = (need_subquery2 or isinstance(other, Aggregation) or
271+
any(n in other.heading.new_attributes for n in join_attributes))
269272
if need_subquery1:
270273
self = self.make_subquery()
271274
if need_subquery2:
@@ -721,8 +724,9 @@ def __and__(self, other):
721724

722725
def join(self, other, left=False):
723726
"""
724-
Joining U with a query expression has the effect of promoting the attributes of U to the primary key of
725-
the other query expression.
727+
Joining U with a query expression has the effect of promoting the attributes of U to
728+
the primary key of the other query expression.
729+
726730
:param other: the other query expression to join with.
727731
:param left: ignored. dj.U always acts as if left=False
728732
:return: a copy of the other query expression with the primary key extended.
@@ -733,12 +737,14 @@ def join(self, other, left=False):
733737
raise DataJointError('Set U can only be joined with a QueryExpression.')
734738
try:
735739
raise DataJointError(
736-
'Attribute `%s` not found' % next(k for k in self.primary_key if k not in other.heading.names))
740+
'Attribute `%s` not found' % next(k for k in self.primary_key
741+
if k not in other.heading.names))
737742
except StopIteration:
738743
pass # all ok
739744
result = copy.copy(other)
740745
result._heading = result.heading.set_primary_key(
741-
other.primary_key + [k for k in self.primary_key if k not in other.primary_key])
746+
other.primary_key + [k for k in self.primary_key
747+
if k not in other.primary_key])
742748
return result
743749

744750
def __mul__(self, other):

datajoint/schemas.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def ordered_dir(class_):
2929
"""
3030
attr_list = list()
3131
for c in reversed(class_.mro()):
32-
attr_list.extend(e for e in (
33-
c._ordered_class_members if hasattr(c, '_ordered_class_members') else c.__dict__)
34-
if e not in attr_list)
32+
attr_list.extend(e for e in c.__dict__ if e not in attr_list)
3533
return attr_list
3634

3735

@@ -374,9 +372,9 @@ def list_tables(self):
374372
as ~logs and ~job
375373
:return: A list of table names from the database schema.
376374
"""
377-
return [table_name for (table_name,) in self.connection.query("""
378-
SELECT table_name FROM information_schema.tables
379-
WHERE table_schema = %s and table_name NOT LIKE '~%%'""", args=(self.database,))]
375+
return [t for d, t in (full_t.replace('`', '').split('.')
376+
for full_t in Diagram(self).topological_sort())
377+
if d == self.database]
380378

381379

382380
class VirtualModule(types.ModuleType):

datajoint/table.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -720,14 +720,18 @@ def lookup_class_name(name, context, depth=3):
720720
if member.full_table_name == name: # found it!
721721
return '.'.join([node['context_name'], member_name]).lstrip('.')
722722
try: # look for part tables
723-
parts = member._ordered_class_members
723+
parts = member.__dict__
724724
except AttributeError:
725725
pass # not a UserTable -- cannot have part tables.
726726
else:
727-
for part in (getattr(member, p) for p in parts if p[0].isupper() and hasattr(member, p)):
728-
if inspect.isclass(part) and issubclass(part, Table) and part.full_table_name == name:
729-
return '.'.join([node['context_name'], member_name, part.__name__]).lstrip('.')
730-
elif node['depth'] > 0 and inspect.ismodule(member) and member.__name__ != 'datajoint':
727+
for part in (getattr(member, p) for p in parts
728+
if p[0].isupper() and hasattr(member, p)):
729+
if inspect.isclass(part) and issubclass(part, Table) and \
730+
part.full_table_name == name:
731+
return '.'.join([node['context_name'],
732+
member_name, part.__name__]).lstrip('.')
733+
elif node['depth'] > 0 and inspect.ismodule(member) and \
734+
member.__name__ != 'datajoint':
731735
try:
732736
nodes.append(
733737
dict(context=dict(inspect.getmembers(member)),

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.13.0"
1+
__version__ = "0.13.1"
22

33
assert len(__version__) <= 10 # The log table limits version to the 10 characters

0 commit comments

Comments
 (0)