Skip to content

Commit f773991

Browse files
Allow None to be used in dict restrictions.
1 parent 2547748 commit f773991

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

datajoint/condition.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ def __init__(self, operand):
2121

2222
class AndList(list):
2323
"""
24-
A list of conditions to by applied to a query expression by logical conjunction: the conditions are AND-ed.
25-
All other collections (lists, sets, other entity sets, etc) are applied by logical disjunction (OR).
24+
A list of conditions to by applied to a query expression by logical conjunction: the
25+
conditions are AND-ed. All other collections (lists, sets, other entity sets, etc) are
26+
applied by logical disjunction (OR).
2627
2728
Example:
2829
expr2 = expr & dj.AndList((cond1, cond2, cond3))
@@ -49,14 +50,16 @@ def assert_join_compatibility(expr1, expr2):
4950
the matching attributes in the two expressions must be in the primary key of one or the
5051
other expression.
5152
Raises an exception if not compatible.
53+
5254
:param expr1: A QueryExpression object
5355
:param expr2: A QueryExpression object
5456
"""
5557
from .expression import QueryExpression, U
5658

5759
for rel in (expr1, expr2):
5860
if not isinstance(rel, (U, QueryExpression)):
59-
raise DataJointError('Object %r is not a QueryExpression and cannot be joined.' % rel)
61+
raise DataJointError(
62+
'Object %r is not a QueryExpression and cannot be joined.' % rel)
6063
if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible
6164
try:
6265
raise DataJointError(
@@ -70,9 +73,11 @@ def assert_join_compatibility(expr1, expr2):
7073
def make_condition(query_expression, condition, columns):
7174
"""
7275
Translate the input condition into the equivalent SQL condition (a string)
76+
7377
:param query_expression: a dj.QueryExpression object to apply condition
7478
:param condition: any valid restriction object.
75-
:param columns: a set passed by reference to collect all column names used in the condition.
79+
:param columns: a set passed by reference to collect all column names used in the
80+
condition.
7681
:return: an SQL condition string or a boolean value.
7782
"""
7883
from .expression import QueryExpression, Aggregation, U
@@ -102,12 +107,13 @@ def prep_value(k, v):
102107
# restrict by string
103108
if isinstance(condition, str):
104109
columns.update(extract_column_names(condition))
105-
return template % condition.strip().replace("%", "%%") # escape % in strings, see issue #376
110+
return template % condition.strip().replace("%", "%%") # escape %, see issue #376
106111

107112
# restrict by AndList
108113
if isinstance(condition, AndList):
109114
# omit all conditions that evaluate to True
110-
items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition)
115+
items = [item for item in (make_condition(query_expression, cond, columns)
116+
for cond in condition)
111117
if item is not True]
112118
if any(item is False for item in items):
113119
return negate # if any item is False, the whole thing is False
@@ -123,18 +129,21 @@ def prep_value(k, v):
123129
if isinstance(condition, bool):
124130
return negate != condition
125131

126-
# restrict by a mapping such as a dict -- convert to an AndList of string equality conditions
132+
# restrict by a mapping/dict -- convert to an AndList of string equality conditions
127133
if isinstance(condition, collections.abc.Mapping):
128134
common_attributes = set(condition).intersection(query_expression.heading.names)
129135
if not common_attributes:
130136
return not negate # no matching attributes -> evaluates to True
131137
columns.update(common_attributes)
132138
return template % ('(' + ') AND ('.join(
133-
'`%s`=%s' % (k, prep_value(k, condition[k])) for k in common_attributes) + ')')
139+
'`%s`%s' % (k, ' IS NULL' if condition[k] is None
140+
else f'={prep_value(k, condition[k])}')
141+
for k in common_attributes) + ')')
134142

135143
# restrict by a numpy record -- convert to an AndList of string equality conditions
136144
if isinstance(condition, numpy.void):
137-
common_attributes = set(condition.dtype.fields).intersection(query_expression.heading.names)
145+
common_attributes = set(condition.dtype.fields).intersection(
146+
query_expression.heading.names)
138147
if not common_attributes:
139148
return not negate # no matching attributes -> evaluate to True
140149
columns.update(common_attributes)
@@ -154,7 +163,8 @@ def prep_value(k, v):
154163
if isinstance(condition, QueryExpression):
155164
if check_compatibility:
156165
assert_join_compatibility(query_expression, condition)
157-
common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names]
166+
common_attributes = [q for q in condition.heading.names
167+
if q in query_expression.heading.names]
158168
columns.update(common_attributes)
159169
if isinstance(condition, Aggregation):
160170
condition = condition.make_subquery()
@@ -176,15 +186,17 @@ def prep_value(k, v):
176186
except TypeError:
177187
raise DataJointError('Invalid restriction type %r' % condition)
178188
else:
179-
or_list = [item for item in or_list if item is not False] # ignore all False conditions
180-
if any(item is True for item in or_list): # if any item is True, the whole thing is True
189+
or_list = [item for item in or_list if item is not False] # ignore False conditions
190+
if any(item is True for item in or_list): # if any item is True, entirely True
181191
return not negate
182-
return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate # an empty or list is False
192+
return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate
183193

184194

185195
def extract_column_names(sql_expression):
186196
"""
187-
extract all presumed column names from an sql expression such as the WHERE clause, for example.
197+
extract all presumed column names from an sql expression such as the WHERE clause,
198+
for example.
199+
188200
:param sql_expression: a string containing an SQL expression
189201
:return: set of extracted column names
190202
This may be MySQL-specific for now.

tests/test_relational_operand.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,3 +470,13 @@ def test_complex_date_restriction():
470470
q = OutfitLaunch & '`day` between curdate() - interval 30 day and curdate()'
471471
assert len(q) == 1
472472
q.delete()
473+
474+
@staticmethod
475+
def test_null_dict_restriction():
476+
# https://github.com/datajoint/datajoint-python/issues/824
477+
"""Test a restriction for null using dict"""
478+
F.insert([dict(id=5)])
479+
q = F & 'date is NULL'
480+
assert len(q) == 1
481+
q = F & dict(date=None)
482+
assert len(q) == 1

0 commit comments

Comments
 (0)