@@ -21,8 +21,9 @@ def __init__(self, operand):
21
21
22
22
class AndList (list ):
23
23
"""
24
- A list of conditions to by applied to a query expression by logical conjunction: the conditions are AND-ed.
25
- All other collections (lists, sets, other entity sets, etc) are applied by logical disjunction (OR).
24
+ A list of conditions to by applied to a query expression by logical conjunction: the
25
+ conditions are AND-ed. All other collections (lists, sets, other entity sets, etc) are
26
+ applied by logical disjunction (OR).
26
27
27
28
Example:
28
29
expr2 = expr & dj.AndList((cond1, cond2, cond3))
@@ -49,14 +50,16 @@ def assert_join_compatibility(expr1, expr2):
49
50
the matching attributes in the two expressions must be in the primary key of one or the
50
51
other expression.
51
52
Raises an exception if not compatible.
53
+
52
54
:param expr1: A QueryExpression object
53
55
:param expr2: A QueryExpression object
54
56
"""
55
57
from .expression import QueryExpression , U
56
58
57
59
for rel in (expr1 , expr2 ):
58
60
if not isinstance (rel , (U , QueryExpression )):
59
- raise DataJointError ('Object %r is not a QueryExpression and cannot be joined.' % rel )
61
+ raise DataJointError (
62
+ 'Object %r is not a QueryExpression and cannot be joined.' % rel )
60
63
if not isinstance (expr1 , U ) and not isinstance (expr2 , U ): # dj.U is always compatible
61
64
try :
62
65
raise DataJointError (
@@ -70,9 +73,11 @@ def assert_join_compatibility(expr1, expr2):
70
73
def make_condition (query_expression , condition , columns ):
71
74
"""
72
75
Translate the input condition into the equivalent SQL condition (a string)
76
+
73
77
:param query_expression: a dj.QueryExpression object to apply condition
74
78
:param condition: any valid restriction object.
75
- :param columns: a set passed by reference to collect all column names used in the condition.
79
+ :param columns: a set passed by reference to collect all column names used in the
80
+ condition.
76
81
:return: an SQL condition string or a boolean value.
77
82
"""
78
83
from .expression import QueryExpression , Aggregation , U
@@ -102,12 +107,13 @@ def prep_value(k, v):
102
107
# restrict by string
103
108
if isinstance (condition , str ):
104
109
columns .update (extract_column_names (condition ))
105
- return template % condition .strip ().replace ("%" , "%%" ) # escape % in strings , see issue #376
110
+ return template % condition .strip ().replace ("%" , "%%" ) # escape %, see issue #376
106
111
107
112
# restrict by AndList
108
113
if isinstance (condition , AndList ):
109
114
# omit all conditions that evaluate to True
110
- items = [item for item in (make_condition (query_expression , cond , columns ) for cond in condition )
115
+ items = [item for item in (make_condition (query_expression , cond , columns )
116
+ for cond in condition )
111
117
if item is not True ]
112
118
if any (item is False for item in items ):
113
119
return negate # if any item is False, the whole thing is False
@@ -123,18 +129,21 @@ def prep_value(k, v):
123
129
if isinstance (condition , bool ):
124
130
return negate != condition
125
131
126
- # restrict by a mapping such as a dict -- convert to an AndList of string equality conditions
132
+ # restrict by a mapping/ dict -- convert to an AndList of string equality conditions
127
133
if isinstance (condition , collections .abc .Mapping ):
128
134
common_attributes = set (condition ).intersection (query_expression .heading .names )
129
135
if not common_attributes :
130
136
return not negate # no matching attributes -> evaluates to True
131
137
columns .update (common_attributes )
132
138
return template % ('(' + ') AND (' .join (
133
- '`%s`=%s' % (k , prep_value (k , condition [k ])) for k in common_attributes ) + ')' )
139
+ '`%s`%s' % (k , ' IS NULL' if condition [k ] is None
140
+ else f'={ prep_value (k , condition [k ])} ' )
141
+ for k in common_attributes ) + ')' )
134
142
135
143
# restrict by a numpy record -- convert to an AndList of string equality conditions
136
144
if isinstance (condition , numpy .void ):
137
- common_attributes = set (condition .dtype .fields ).intersection (query_expression .heading .names )
145
+ common_attributes = set (condition .dtype .fields ).intersection (
146
+ query_expression .heading .names )
138
147
if not common_attributes :
139
148
return not negate # no matching attributes -> evaluate to True
140
149
columns .update (common_attributes )
@@ -154,7 +163,8 @@ def prep_value(k, v):
154
163
if isinstance (condition , QueryExpression ):
155
164
if check_compatibility :
156
165
assert_join_compatibility (query_expression , condition )
157
- common_attributes = [q for q in condition .heading .names if q in query_expression .heading .names ]
166
+ common_attributes = [q for q in condition .heading .names
167
+ if q in query_expression .heading .names ]
158
168
columns .update (common_attributes )
159
169
if isinstance (condition , Aggregation ):
160
170
condition = condition .make_subquery ()
@@ -176,15 +186,17 @@ def prep_value(k, v):
176
186
except TypeError :
177
187
raise DataJointError ('Invalid restriction type %r' % condition )
178
188
else :
179
- or_list = [item for item in or_list if item is not False ] # ignore all False conditions
180
- if any (item is True for item in or_list ): # if any item is True, the whole thing is True
189
+ or_list = [item for item in or_list if item is not False ] # ignore False conditions
190
+ if any (item is True for item in or_list ): # if any item is True, entirely True
181
191
return not negate
182
- return template % ('(%s)' % ' OR ' .join (or_list )) if or_list else negate # an empty or list is False
192
+ return template % ('(%s)' % ' OR ' .join (or_list )) if or_list else negate
183
193
184
194
185
195
def extract_column_names (sql_expression ):
186
196
"""
187
- extract all presumed column names from an sql expression such as the WHERE clause, for example.
197
+ extract all presumed column names from an sql expression such as the WHERE clause,
198
+ for example.
199
+
188
200
:param sql_expression: a string containing an SQL expression
189
201
:return: set of extracted column names
190
202
This may be MySQL-specific for now.
@@ -206,5 +218,8 @@ def extract_column_names(sql_expression):
206
218
s = re .sub (r"(\b[a-z][a-z_0-9]*)\(" , "(" , s )
207
219
remaining_tokens = set (re .findall (r"\b[a-z][a-z_0-9]*\b" , s ))
208
220
# update result removing reserved words
209
- result .update (remaining_tokens - {"is" , "in" , "between" , "like" , "and" , "or" , "null" , "not" })
221
+ result .update (remaining_tokens - {"is" , "in" , "between" , "like" , "and" , "or" , "null" ,
222
+ "not" , "interval" , "second" , "minute" , "hour" , "day" ,
223
+ "month" , "week" , "year"
224
+ })
210
225
return result
0 commit comments