Skip to content

Commit 06e441c

Browse files
Merge pull request #870 from dimitri-yatsenko/cascade-delete
Cascade delete
2 parents c14eb73 + eb72758 commit 06e441c

File tree

3 files changed

+38
-25
lines changed

3 files changed

+38
-25
lines changed

datajoint/fetch.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -149,26 +149,30 @@ def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None,
149149
attrs = list(self._expression.primary_key) + [
150150
a for a in attrs if a not in self._expression.primary_key]
151151
if as_dict is None:
152-
as_dict = bool(attrs) # default to True for "KEY" and False when fetching entire result
152+
as_dict = bool(attrs) # default to True for "KEY" and False otherwise
153153
# format should not be specified with attrs or is_dict=True
154154
if format is not None and (as_dict or attrs):
155155
raise DataJointError('Cannot specify output format when as_dict=True or '
156156
'when attributes are selected to be fetched separately.')
157157
if format not in {None, "array", "frame"}:
158-
raise DataJointError('Fetch output format must be in {{"array", "frame"}} but "{}" was given'.format(format))
158+
raise DataJointError(
159+
'Fetch output format must be in '
160+
'{{"array", "frame"}} but "{}" was given'.format(format))
159161

160162
if not (attrs or as_dict) and format is None:
161163
format = config['fetch_format'] # default to array
162164
if format not in {"array", "frame"}:
163-
raise DataJointError('Invalid entry "{}" in datajoint.config["fetch_format"]: use "array" or "frame"'.format(
164-
format))
165+
raise DataJointError(
166+
'Invalid entry "{}" in datajoint.config["fetch_format"]: '
167+
'use "array" or "frame"'.format(format))
165168

166169
if limit is None and offset is not None:
167170
warnings.warn('Offset set, but no limit. Setting limit to a large number. '
168171
'Consider setting a limit explicitly.')
169172
limit = 8000000000 # just a very large number to effect no limit
170173

171-
get = partial(_get, self._expression.connection, squeeze=squeeze, download_path=download_path)
174+
get = partial(_get, self._expression.connection,
175+
squeeze=squeeze, download_path=download_path)
172176
if attrs: # a list of attributes provided
173177
attributes = [a for a in attrs if not is_key(a)]
174178
ret = self._expression.proj(*attributes)
@@ -179,19 +183,22 @@ def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None,
179183
if attrs_as_dict:
180184
ret = [{k: v for k, v in zip(ret.dtype.names, x) if k in attrs} for x in ret]
181185
else:
182-
return_values = [
183-
list((to_dicts if as_dict else lambda x: x)(ret[self._expression.primary_key])) if is_key(attribute)
184-
else ret[attribute] for attribute in attrs]
186+
return_values = [list(
187+
(to_dicts if as_dict else lambda x: x)(ret[self._expression.primary_key]))
188+
if is_key(attribute) else ret[attribute]
189+
for attribute in attrs]
185190
ret = return_values[0] if len(attrs) == 1 else return_values
186191
else: # fetch all attributes as a numpy.record_array or pandas.DataFrame
187-
cur = self._expression.cursor(as_dict=as_dict, limit=limit, offset=offset, order_by=order_by)
192+
cur = self._expression.cursor(
193+
as_dict=as_dict, limit=limit, offset=offset, order_by=order_by)
188194
heading = self._expression.heading
189195
if as_dict:
190-
ret = [dict((name, get(heading[name], d[name])) for name in heading.names) for d in cur]
196+
ret = [dict((name, get(heading[name], d[name]))
197+
for name in heading.names) for d in cur]
191198
else:
192199
ret = list(cur.fetchall())
193200
record_type = (heading.as_dtype if not ret else np.dtype(
194-
[(name, type(value)) # use the first element to determine the type for blobs
201+
[(name, type(value)) # use the first element to determine blob type
195202
if heading[name].is_blob and isinstance(value, numbers.Number)
196203
else (name, heading.as_dtype[name])
197204
for value, name in zip(ret[0], heading.as_dtype.names)]))
@@ -208,15 +215,15 @@ def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None,
208215

209216
class Fetch1:
210217
"""
211-
Fetch object for fetching exactly one row.
212-
:param relation: relation the fetch object fetches data from
218+
Fetch object for fetching the result of a query yielding one row.
219+
:param expression: a query expression to fetch from.
213220
"""
214-
def __init__(self, relation):
215-
self._expression = relation
221+
def __init__(self, expression):
222+
self._expression = expression
216223

217224
def __call__(self, *attrs, squeeze=False, download_path='.'):
218225
"""
219-
Fetches the expression results from the database when the expression is known to yield only one entry.
226+
Fetches the result of a query expression that yields one entry.
220227
221228
If no attributes are specified, returns the result as a dict.
222229
If attributes are specified returns the corresponding results as a tuple.
@@ -225,7 +232,8 @@ def __call__(self, *attrs, squeeze=False, download_path='.'):
225232
d = rel.fetch1() # as a dictionary
226233
a, b = rel.fetch1('a', 'b') # as a tuple
227234
228-
:params *attrs: attributes to return when expanding into a tuple. If empty, the return result is a dict
235+
:params *attrs: attributes to return when expanding into a tuple.
236+
If attrs is empty, the return result is a dict
229237
:param squeeze: When true, remove extra dimensions from arrays in attributes
230238
:param download_path: for fetches that download data, e.g. attachments
231239
:return: the one tuple in the relation in the form of a dict
@@ -236,17 +244,20 @@ def __call__(self, *attrs, squeeze=False, download_path='.'):
236244
cur = self._expression.cursor(as_dict=True)
237245
ret = cur.fetchone()
238246
if not ret or cur.fetchone():
239-
raise DataJointError('fetch1 should only be used for relations with exactly one tuple')
247+
raise DataJointError('fetch1 requires exactly one tuple in the input set.')
240248
ret = dict((name, _get(self._expression.connection, heading[name], ret[name],
241249
squeeze=squeeze, download_path=download_path))
242250
for name in heading.names)
243251
else: # fetch some attributes, return as tuple
244252
attributes = [a for a in attrs if not is_key(a)]
245-
result = self._expression.proj(*attributes).fetch(squeeze=squeeze, download_path=download_path)
253+
result = self._expression.proj(*attributes).fetch(
254+
squeeze=squeeze, download_path=download_path)
246255
if len(result) != 1:
247-
raise DataJointError('fetch1 should only return one tuple. %d tuples were found' % len(result))
256+
raise DataJointError(
257+
'fetch1 should only return one tuple. %d tuples found' % len(result))
248258
return_values = tuple(
249-
next(to_dicts(result[self._expression.primary_key])) if is_key(attribute) else result[attribute][0]
259+
next(to_dicts(result[self._expression.primary_key]))
260+
if is_key(attribute) else result[attribute][0]
250261
for attribute in attrs)
251262
ret = return_values[0] if len(attrs) == 1 else return_values
252263
return ret

datajoint/table.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def delete(self, transaction=True, safemode=None):
378378
:param transaction: if True, use the entire delete becomes an atomic transaction.
379379
:param safemode: If True, prohibit nested transactions and prompt to confirm. Default is dj.config['safemode'].
380380
"""
381-
safemode = safemode or config['safemode']
381+
safemode = config['safemode'] if safemode is None else safemode
382382

383383
# Start transaction
384384
if transaction:
@@ -408,11 +408,13 @@ def delete(self, transaction=True, safemode=None):
408408
self.connection.cancel_transaction()
409409
else:
410410
if not safemode or user_choice("Commit deletes?", default='no') == 'yes':
411-
self.connection.commit_transaction()
411+
if transaction:
412+
self.connection.commit_transaction()
412413
if safemode:
413414
print('Deletes committed.')
414415
else:
415-
self.connection.cancel_transaction()
416+
if transaction:
417+
self.connection.cancel_transaction()
416418
if safemode:
417419
print('Deletes cancelled')
418420

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.13.dev3"
1+
__version__ = "0.13.dev4"
22

33
assert len(__version__) <= 10 # The log table limits version to the 10 characters

0 commit comments

Comments
 (0)