
Commit 3d8c6c3

Resolve merge conflicts.
Merge commit 3d8c6c3 with 2 parents: d542fe1 + 304295e


6 files changed: +258, -112 lines


datajoint/table.py

Lines changed: 139 additions & 107 deletions
@@ -205,6 +205,40 @@ def _log(self):
     def external(self):
         return self.connection.schemas[self.database].external
 
+    def update1(self, row):
+        """
+        Update an existing entry in the table.
+        Caution: Updates are not part of the DataJoint data manipulation model. For strict data integrity,
+        use delete and insert.
+        :param row: a dict containing the primary key and the attributes to update.
+            Setting an attribute value to None will reset it to the default value (if any)
+            The primary key attributes must always be provided.
+        Examples:
+        >>> table.update1({'id': 1, 'value': 3})  # update value in record with id=1
+        >>> table.update1({'id': 1, 'value': None})  # reset value to default
+        """
+        # argument validations
+        if not isinstance(row, collections.abc.Mapping):
+            raise DataJointError('The argument of update1 must be dict-like.')
+        if not set(row).issuperset(self.primary_key):
+            raise DataJointError('The argument of update1 must supply all primary key values.')
+        try:
+            raise DataJointError('Attribute `%s` not found.' % next(k for k in row if k not in self.heading.names))
+        except StopIteration:
+            pass  # ok
+        if len(self.restriction):
+            raise DataJointError('Update cannot be applied to a restricted table.')
+        key = {k: row[k] for k in self.primary_key}
+        if len(self & key) != 1:
+            raise DataJointError('Update entry must exist.')
+        # UPDATE query
+        row = [self.__make_placeholder(k, v) for k, v in row.items() if k not in self.primary_key]
+        query = "UPDATE {table} SET {assignments} WHERE {where}".format(
+            table=self.full_table_name,
+            assignments=",".join('`%s`=%s' % r[:2] for r in row),
+            where=self._make_condition(key))
+        self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))
+
     def insert1(self, row, **kwargs):
         """
         Insert one data record or one Mapping (like a dict).
@@ -244,7 +278,6 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
                 'Inserts into an auto-populated table can only done inside its make method during a populate call.'
                 ' To override, set keyword argument allow_direct_insert=True.')
 
-        heading = self.heading
         if inspect.isclass(rows) and issubclass(rows, QueryExpression):  # instantiate if a class
             rows = rows()
         if isinstance(rows, QueryExpression):
@@ -253,10 +286,10 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
                try:
                    raise DataJointError(
                        "Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." %
-                        next(name for name in rows.heading if name not in heading))
+                        next(name for name in rows.heading if name not in self.heading))
                except StopIteration:
                    pass
-            fields = list(name for name in rows.heading if name in heading)
+            fields = list(name for name in rows.heading if name in self.heading)
             query = '{command} INTO {table} ({fields}) {select}{duplicate}'.format(
                 command='REPLACE' if replace else 'INSERT',
                 fields='`' + '`,`'.join(fields) + '`',
@@ -268,110 +301,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
             self.connection.query(query)
             return
 
-        if heading.attributes is None:
-            logger.warning('Could not access table {table}'.format(table=self.full_table_name))
-            return
-
-        field_list = None  # ensures that all rows have the same attributes in the same order as the first row.
-
-        def make_row_to_insert(row):
-            """
-            :param row: A tuple to insert
-            :return: a dict with fields 'names', 'placeholders', 'values'
-            """
-            def make_placeholder(name, value):
-                """
-                For a given attribute `name` with `value`, return its processed value or value placeholder
-                as a string to be included in the query and the value, if any, to be submitted for
-                processing by mysql API.
-                :param name: name of attribute to be inserted
-                :param value: value of attribute to be inserted
-                """
-                if ignore_extra_fields and name not in heading:
-                    return None
-                attr = heading[name]
-                if attr.adapter:
-                    value = attr.adapter.put(value)
-                if value is None or (attr.numeric and (value == '' or np.isnan(np.float(value)))):
-                    # set default value
-                    placeholder, value = 'DEFAULT', None
-                else:  # not NULL
-                    placeholder = '%s'
-                    if attr.uuid:
-                        if not isinstance(value, uuid.UUID):
-                            try:
-                                value = uuid.UUID(value)
-                            except (AttributeError, ValueError):
-                                raise DataJointError(
-                                    'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name))
-                        value = value.bytes
-                    elif attr.is_blob:
-                        value = blob.pack(value)
-                        value = self.external[attr.store].put(value).bytes if attr.is_external else value
-                    elif attr.is_attachment:
-                        attachment_path = Path(value)
-                        if attr.is_external:
-                            # value is hash of contents
-                            value = self.external[attr.store].upload_attachment(attachment_path).bytes
-                        else:
-                            # value is filename + contents
-                            value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
-                    elif attr.is_filepath:
-                        value = self.external[attr.store].upload_filepath(value).bytes
-                    elif attr.numeric:
-                        value = str(int(value) if isinstance(value, bool) else value)
-                return name, placeholder, value
-
-            def check_fields(fields):
-                """
-                Validates that all items in `fields` are valid attributes in the heading
-                :param fields: field names of a tuple
-                """
-                if field_list is None:
-                    if not ignore_extra_fields:
-                        for field in fields:
-                            if field not in heading:
-                                raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
-                elif set(field_list) != set(fields).intersection(heading.names):
-                    raise DataJointError('Attempt to insert rows with different fields')
-
-            if isinstance(row, np.void):  # np.array
-                check_fields(row.dtype.fields)
-                attributes = [make_placeholder(name, row[name])
-                              for name in heading if name in row.dtype.fields]
-            elif isinstance(row, collections.abc.Mapping):  # dict-based
-                check_fields(row)
-                attributes = [make_placeholder(name, row[name]) for name in heading if name in row]
-            else:  # positional
-                try:
-                    if len(row) != len(heading):
-                        raise DataJointError(
-                            'Invalid insert argument. Incorrect number of attributes: '
-                            '{given} given; {expected} expected'.format(
-                                given=len(row), expected=len(heading)))
-                except TypeError:
-                    raise DataJointError('Datatype %s cannot be inserted' % type(row))
-                else:
-                    attributes = [make_placeholder(name, value) for name, value in zip(heading, row)]
-            if ignore_extra_fields:
-                attributes = [a for a in attributes if a is not None]
-
-            assert len(attributes), 'Empty tuple'
-            row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
-            nonlocal field_list
-            if field_list is None:
-                # first row sets the composition of the field list
-                field_list = row_to_insert['names']
-            else:
-                # reorder attributes in row_to_insert to match field_list
-                order = list(row_to_insert['names'].index(field) for field in field_list)
-                row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
-                row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
-                row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)
-
-            return row_to_insert
-
-        rows = list(make_row_to_insert(row) for row in rows)
+        field_list = []  # collects the field list from first row (passed by reference)
+        rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
         if rows:
             try:
                 query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
@@ -638,6 +569,107 @@ def _update(self, attrname, value=None):
             where_clause=self.where_clause)
         self.connection.query(command, args=(value, ) if value is not None else ())
 
+    # --- private helper functions ----
+    def __make_placeholder(self, name, value, ignore_extra_fields=False):
+        """
+        For a given attribute `name` with `value`, return its processed value or value placeholder
+        as a string to be included in the query and the value, if any, to be submitted for
+        processing by mysql API.
+        :param name: name of attribute to be inserted
+        :param value: value of attribute to be inserted
+        """
+        if ignore_extra_fields and name not in self.heading:
+            return None
+        attr = self.heading[name]
+        if attr.adapter:
+            value = attr.adapter.put(value)
+        if value is None or (attr.numeric and (value == '' or np.isnan(np.float(value)))):
+            # set default value
+            placeholder, value = 'DEFAULT', None
+        else:  # not NULL
+            placeholder = '%s'
+            if attr.uuid:
+                if not isinstance(value, uuid.UUID):
+                    try:
+                        value = uuid.UUID(value)
+                    except (AttributeError, ValueError):
+                        raise DataJointError(
+                            'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name))
+                value = value.bytes
+            elif attr.is_blob:
+                value = blob.pack(value)
+                value = self.external[attr.store].put(value).bytes if attr.is_external else value
+            elif attr.is_attachment:
+                attachment_path = Path(value)
+                if attr.is_external:
+                    # value is hash of contents
+                    value = self.external[attr.store].upload_attachment(attachment_path).bytes
+                else:
+                    # value is filename + contents
+                    value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
+            elif attr.is_filepath:
+                value = self.external[attr.store].upload_filepath(value).bytes
+            elif attr.numeric:
+                value = str(int(value) if isinstance(value, bool) else value)
+        return name, placeholder, value
+
+    def __make_row_to_insert(self, row, field_list, ignore_extra_fields):
+        """
+        Helper function for insert and update
+        :param row: A tuple to insert
+        :return: a dict with fields 'names', 'placeholders', 'values'
+        """
+
+        def check_fields(fields):
+            """
+            Validates that all items in `fields` are valid attributes in the heading
+            :param fields: field names of a tuple
+            """
+            if not field_list:
+                if not ignore_extra_fields:
+                    for field in fields:
+                        if field not in self.heading:
+                            raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
+            elif set(field_list) != set(fields).intersection(self.heading.names):
+                raise DataJointError('Attempt to insert rows with different fields')
+
+        if isinstance(row, np.void):  # np.array
+            check_fields(row.dtype.fields)
+            attributes = [self.__make_placeholder(name, row[name], ignore_extra_fields)
+                          for name in self.heading if name in row.dtype.fields]
+        elif isinstance(row, collections.abc.Mapping):  # dict-based
+            check_fields(row)
+            attributes = [self.__make_placeholder(name, row[name], ignore_extra_fields)
+                          for name in self.heading if name in row]
+        else:  # positional
+            try:
+                if len(row) != len(self.heading):
+                    raise DataJointError(
+                        'Invalid insert argument. Incorrect number of attributes: '
+                        '{given} given; {expected} expected'.format(
+                            given=len(row), expected=len(self.heading)))
+            except TypeError:
+                raise DataJointError('Datatype %s cannot be inserted' % type(row))
+            else:
+                attributes = [self.__make_placeholder(name, value, ignore_extra_fields)
+                              for name, value in zip(self.heading, row)]
+        if ignore_extra_fields:
+            attributes = [a for a in attributes if a is not None]
+
+        assert len(attributes), 'Empty tuple'
+        row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
+        if not field_list:
+            # first row sets the composition of the field list
+            field_list.extend(row_to_insert['names'])
+        else:
+            # reorder attributes in row_to_insert to match field_list
+            order = list(row_to_insert['names'].index(field) for field in field_list)
+            row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
+            row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
+            row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)
+        return row_to_insert
+
 
 def lookup_class_name(name, context, depth=3):
     """

datajoint/user_tables.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
# attributes that trigger instantiation of user classes
1414
supported_class_attrs = {
15-
'key_source', 'describe', 'alter', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
16-
'fetch', 'fetch1', 'head', 'tail',
15+
'key_source', 'describe', 'alter', 'heading', 'populate', 'progress', 'primary_key',
16+
'proj', 'aggr', 'fetch', 'fetch1', 'head', 'tail',
1717
'descendants', 'parts', 'ancestors', 'parents', 'children',
18-
'insert', 'insert1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
18+
'insert', 'insert1', 'update1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
1919

2020

2121
class OrderedClass(type):

tests/test_attach.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 from nose.tools import assert_true, assert_equal, assert_not_equal
-from numpy.testing import assert_array_equal
 import tempfile
 from pathlib import Path
 import os

tests/test_fetch_same.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 import datajoint as dj
 
 schema = dj.Schema(PREFIX + '_fetch_same', connection=dj.conn(**CONN_INFO))
+dj.config['enable_python_native_blobs'] = True
 
 
 @schema
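
Note: the test setup above turns on dj.config['enable_python_native_blobs']. As a rough, hedged illustration of why a test module might need it, the flag allows native Python objects to be serialized into blob attributes; the Stimulus table and its blob attribute below are hypothetical, and the exact set of types gated by the flag is an assumption here.

dj.config['enable_python_native_blobs'] = True  # same setting as in the test module above

# Hypothetical table with a longblob attribute named `data`:
Stimulus.insert1({'id': 1, 'data': {'trials': [1, 2, 3], 'note': 'native dict in a blob'}})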

tests/test_reconnection.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 Collection of test cases to test connection module.
 """
 
-from nose.tools import assert_true, assert_false, assert_equal, raises
+from nose.tools import assert_true, assert_false, raises
 import datajoint as dj
 from datajoint import DataJointError
 from . import CONN_INFO
